mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-26 19:22:00 +00:00
Merge branch 'rag_scoring_fn_2' into rag_scoring_fn_3
This commit is contained in:
commit
ae9291e0a0
127 changed files with 6386 additions and 493 deletions
|
|
@ -13,19 +13,64 @@ import secrets
|
|||
import string
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Tuple
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentTool,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResponseEvent,
|
||||
AgentTurnResponseEventType,
|
||||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseStepProgressPayload,
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStreamChunk,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
Attachment,
|
||||
CodeInterpreterToolDefinition,
|
||||
FunctionCallToolDefinition,
|
||||
InferenceStep,
|
||||
MemoryRetrievalStep,
|
||||
MemoryToolDefinition,
|
||||
PhotogenToolDefinition,
|
||||
SearchToolDefinition,
|
||||
ShieldCallStep,
|
||||
StepType,
|
||||
ToolExecutionStep,
|
||||
Turn,
|
||||
WolframAlphaToolDefinition,
|
||||
)
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
TextContentItem,
|
||||
URL,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponseEventType,
|
||||
CompletionMessage,
|
||||
Inference,
|
||||
Message,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
SystemMessage,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
|
||||
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||
|
|
|
|||
|
|
@ -9,15 +9,26 @@ import logging
|
|||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
Agents,
|
||||
AgentSessionCreateResponse,
|
||||
AgentStepResponse,
|
||||
AgentTurnCreateRequest,
|
||||
Attachment,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
|
||||
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
|
||||
|
|
|
|||
|
|
@ -10,9 +10,11 @@ import uuid
|
|||
from datetime import datetime
|
||||
|
||||
from typing import List, Optional
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Turn
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -7,8 +7,6 @@
|
|||
from typing import List
|
||||
|
||||
from jinja2 import Template
|
||||
from llama_models.llama3.api import * # noqa: F403
|
||||
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
DefaultMemoryQueryGeneratorConfig,
|
||||
|
|
@ -16,7 +14,7 @@ from llama_stack.apis.agents import (
|
|||
MemoryQueryGenerator,
|
||||
MemoryQueryGeneratorConfig,
|
||||
)
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.inference import Message, UserMessage
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ import logging
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.inference import Message
|
||||
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,10 +8,26 @@ from typing import AsyncIterator, List, Optional, Union
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
)
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.memory import MemoryBank
|
||||
from llama_stack.apis.safety import RunShieldResponse
|
||||
|
||||
from ..agents import (
|
||||
AGENT_INSTANCES_BY_ID,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
from typing import List
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
||||
from ..safety import ShieldRunnerMixin
|
||||
from .builtin import BaseTool
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.datasetio import * # noqa: F401, F403
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LocalFSDatasetIOConfig(BaseModel): ...
|
||||
|
|
|
|||
|
|
@ -3,18 +3,19 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
import base64
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pandas
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
|
||||
|
|
|
|||
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F401, F403
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from llama_stack.apis.inference import QuantizationConfig
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -32,11 +32,16 @@ from llama_models.llama3.reference_impl.multimodal.model import (
|
|||
CrossAttentionTransformer,
|
||||
)
|
||||
from llama_models.sku_list import resolve_model
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Fp8QuantizationConfig,
|
||||
Int4QuantizationConfig,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
|
|
@ -44,12 +49,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .config import (
|
||||
Fp8QuantizationConfig,
|
||||
Int4QuantizationConfig,
|
||||
MetaReferenceInferenceConfig,
|
||||
MetaReferenceQuantizedInferenceConfig,
|
||||
)
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,10 @@ from llama_models.llama3.api.datatypes import Model
|
|||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama, model_checkpoint_dir
|
||||
|
|
@ -27,9 +30,9 @@ class ModelRunner:
|
|||
|
||||
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
|
||||
def __call__(self, req: Any):
|
||||
if isinstance(req, ChatCompletionRequest):
|
||||
if isinstance(req, ChatCompletionRequestWithRawContent):
|
||||
return self.llama.chat_completion(req)
|
||||
elif isinstance(req, CompletionRequest):
|
||||
elif isinstance(req, CompletionRequestWithRawContent):
|
||||
return self.llama.completion(req)
|
||||
else:
|
||||
raise ValueError(f"Unexpected task type {type(req)}")
|
||||
|
|
@ -100,7 +103,7 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
request: CompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request)
|
||||
gen = self.group.run_inference(req_obj)
|
||||
|
|
@ -108,7 +111,7 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
request: ChatCompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request)
|
||||
gen = self.group.run_inference(req_obj)
|
||||
|
|
|
|||
|
|
@ -34,7 +34,10 @@ from pydantic import BaseModel, Field
|
|||
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .generation import TokenResult
|
||||
|
||||
|
|
@ -79,7 +82,7 @@ class TaskRequest(BaseModel):
|
|||
type: Literal[ProcessingMessageName.task_request] = (
|
||||
ProcessingMessageName.task_request
|
||||
)
|
||||
task: Union[CompletionRequest, ChatCompletionRequest]
|
||||
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
|
|
@ -264,9 +267,6 @@ def launch_dist_group(
|
|||
init_model_cb: Callable,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
id = uuid.uuid4().hex
|
||||
dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
|
||||
launch_config = LaunchConfig(
|
||||
|
|
@ -315,7 +315,7 @@ def start_model_parallel_process(
|
|||
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
|
||||
|
||||
request_socket.send(encode_msg(ReadyRequest()))
|
||||
response = request_socket.recv()
|
||||
_response = request_socket.recv()
|
||||
log.info("Loaded model...")
|
||||
|
||||
return request_socket, process
|
||||
|
|
@ -349,7 +349,10 @@ class ModelParallelProcessGroup:
|
|||
self.started = False
|
||||
|
||||
def run_inference(
|
||||
self, req: Union[CompletionRequest, ChatCompletionRequest]
|
||||
self,
|
||||
req: Union[
|
||||
CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent
|
||||
],
|
||||
) -> Generator:
|
||||
assert not self.running, "inference already running"
|
||||
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@
|
|||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Optional
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
|
|
@ -18,9 +18,26 @@ from vllm.engine.arg_utils import AsyncEngineArgs
|
|||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
|
|
|
|||
|
|
@ -16,11 +16,14 @@ import faiss
|
|||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
|
||||
from llama_stack.apis.memory import (
|
||||
Chunk,
|
||||
Memory,
|
||||
MemoryBankDocument,
|
||||
QueryDocumentsResponse,
|
||||
)
|
||||
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
|
||||
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
|
|
|
|||
|
|
@ -14,11 +14,10 @@ from enum import Enum
|
|||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.common.type_system import * # noqa
|
||||
from llama_models.datatypes import Model
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.common.type_system import ParamType, StringType
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
||||
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
|
|
|
|||
|
|
@ -3,11 +3,26 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_models.schema_utils import webmethod
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
LoraFinetuningConfig,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.apis.post_training import * # noqa
|
||||
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
|
||||
LoraFinetuningSingleDevice,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,27 +14,33 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
import torch
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.common.training_types import PostTrainingMetric
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
|
||||
TorchtuneCheckpointer,
|
||||
)
|
||||
from torch import nn
|
||||
from torchtune import utils as torchtune_utils
|
||||
from torchtune.training.metric_logging import DiskLogger
|
||||
from tqdm import tqdm
|
||||
from llama_stack.apis.post_training import * # noqa
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
||||
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
|
||||
TorchtuneCheckpointer,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from torchtune import modules, training
|
||||
from torchtune import modules, training, utils as torchtune_utils
|
||||
from torchtune.data import AlpacaToMessages, padded_collate_sft
|
||||
|
||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||
|
|
@ -43,11 +49,12 @@ from torchtune.modules.peft import (
|
|||
get_adapter_state_dict,
|
||||
get_lora_module_names,
|
||||
get_merged_lora_ckpt,
|
||||
load_dora_magnitudes,
|
||||
set_trainable_params,
|
||||
validate_missing_and_unexpected_for_lora,
|
||||
)
|
||||
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
|
||||
from torchtune.training.metric_logging import DiskLogger
|
||||
from tqdm import tqdm
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -110,6 +117,10 @@ class LoraFinetuningSingleDevice:
|
|||
self.checkpoint_dir = config.checkpoint_dir
|
||||
else:
|
||||
model = resolve_model(self.model_id)
|
||||
if model is None:
|
||||
raise ValueError(
|
||||
f"{self.model_id} not found. Your model id should be in the llama models SKU list"
|
||||
)
|
||||
self.checkpoint_dir = model_checkpoint_dir(model)
|
||||
|
||||
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
|
||||
|
|
@ -277,7 +288,6 @@ class LoraFinetuningSingleDevice:
|
|||
for m in model.modules():
|
||||
if hasattr(m, "initialize_dora_magnitude"):
|
||||
m.initialize_dora_magnitude()
|
||||
load_dora_magnitudes(model)
|
||||
if lora_weights_state_dict:
|
||||
lora_missing, lora_unexpected = model.load_state_dict(
|
||||
lora_weights_state_dict, strict=False
|
||||
|
|
|
|||
|
|
@ -7,8 +7,14 @@
|
|||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,10 +9,24 @@ import re
|
|||
from string import Template
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.datatypes import CoreModelId
|
||||
from llama_models.llama3.api.datatypes import Role
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponseEventType,
|
||||
Inference,
|
||||
Message,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
|
|
|||
|
|
@ -11,11 +11,16 @@ import torch
|
|||
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
|
|
|
|||
|
|
@ -3,14 +3,17 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.scoring import (
|
||||
ScoreBatchResponse,
|
||||
ScoreResponse,
|
||||
Scoring,
|
||||
ScoringResult,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
|
||||
DataSchemaValidatorMixin,
|
||||
|
|
|
|||
|
|
@ -3,16 +3,8 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import List
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from autoevals.llm import Factuality
|
||||
from autoevals.ragas import (
|
||||
|
|
@ -27,6 +19,17 @@ from autoevals.ragas import (
|
|||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.scoring import (
|
||||
ScoreBatchResponse,
|
||||
ScoreResponse,
|
||||
Scoring,
|
||||
ScoringResult,
|
||||
ScoringResultRow,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFn
|
||||
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
from llama_stack.providers.utils.common.data_schema_validator_mixin import (
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BraintrustScoringConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -17,6 +17,22 @@ from opentelemetry.sdk.trace import TracerProvider
|
|||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
|
||||
from llama_stack.apis.telemetry import (
|
||||
Event,
|
||||
MetricEvent,
|
||||
QueryCondition,
|
||||
SpanEndPayload,
|
||||
SpanStartPayload,
|
||||
SpanStatus,
|
||||
SpanWithStatus,
|
||||
StructuredLogEvent,
|
||||
Telemetry,
|
||||
Trace,
|
||||
UnstructuredLogEvent,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
|
||||
ConsoleSpanProcessor,
|
||||
)
|
||||
|
|
@ -27,10 +43,6 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor
|
|||
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
|
||||
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import TelemetryConfig, TelemetrySink
|
||||
|
||||
_GLOBAL_STORAGE = {
|
||||
|
|
|
|||
|
|
@ -4,12 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
||||
|
||||
class SampleTelemetryImpl(Telemetry):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .brave_search import BraveSearchToolRuntimeImpl
|
||||
from .config import BraveSearchToolConfig
|
||||
|
||||
|
||||
class BraveSearchToolProviderDataValidator(BaseModel):
|
||||
api_key: str
|
||||
|
||||
|
||||
async def get_provider_impl(config: BraveSearchToolConfig, _deps):
|
||||
impl = BraveSearchToolRuntimeImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
from .config import BraveSearchToolConfig
|
||||
|
||||
|
||||
class BraveSearchToolRuntimeImpl(
|
||||
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: BraveSearchToolConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
if tool.identifier != "brave_search":
|
||||
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
return
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key:
|
||||
return self.config.api_key
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.api_key:
|
||||
raise ValueError(
|
||||
'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
|
||||
)
|
||||
return provider_data.api_key
|
||||
|
||||
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
|
||||
raise NotImplementedError("Brave search tool group not supported")
|
||||
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, args: Dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
url = "https://api.search.brave.com/res/v1/web/search"
|
||||
headers = {
|
||||
"X-Subscription-Token": api_key,
|
||||
"Accept-Encoding": "gzip",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
payload = {"q": args["query"]}
|
||||
response = requests.get(url=url, params=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
results = self._clean_brave_response(response.json())
|
||||
content_items = "\n".join([str(result) for result in results])
|
||||
return ToolInvocationResult(
|
||||
content=content_items,
|
||||
)
|
||||
|
||||
def _clean_brave_response(self, search_response):
|
||||
clean_response = []
|
||||
if "mixed" in search_response:
|
||||
mixed_results = search_response["mixed"]
|
||||
for m in mixed_results["main"][: self.config.max_results]:
|
||||
r_type = m["type"]
|
||||
results = search_response[r_type]["results"]
|
||||
cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
|
||||
clean_response.append(cleaned)
|
||||
|
||||
return clean_response
|
||||
|
||||
def _clean_result_by_type(self, r_type, results, idx=None):
|
||||
type_cleaners = {
|
||||
"web": (
|
||||
["type", "title", "url", "description", "date", "extra_snippets"],
|
||||
lambda x: x[idx],
|
||||
),
|
||||
"faq": (["type", "question", "answer", "title", "url"], lambda x: x),
|
||||
"infobox": (
|
||||
["type", "title", "url", "description", "long_desc"],
|
||||
lambda x: x[idx],
|
||||
),
|
||||
"videos": (["type", "url", "title", "description", "date"], lambda x: x),
|
||||
"locations": (
|
||||
[
|
||||
"type",
|
||||
"title",
|
||||
"url",
|
||||
"description",
|
||||
"coordinates",
|
||||
"postal_address",
|
||||
"contact",
|
||||
"rating",
|
||||
"distance",
|
||||
"zoom_level",
|
||||
],
|
||||
lambda x: x,
|
||||
),
|
||||
"news": (["type", "title", "url", "description"], lambda x: x),
|
||||
}
|
||||
|
||||
if r_type not in type_cleaners:
|
||||
return ""
|
||||
|
||||
selected_keys, result_selector = type_cleaners[r_type]
|
||||
results = result_selector(results)
|
||||
|
||||
if isinstance(results, list):
|
||||
cleaned = [
|
||||
{k: v for k, v in item.items() if k in selected_keys}
|
||||
for item in results
|
||||
]
|
||||
else:
|
||||
cleaned = {k: v for k, v in results.items() if k in selected_keys}
|
||||
|
||||
return str(cleaned)
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BraveSearchToolConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Brave Search API Key",
|
||||
)
|
||||
max_results: int = Field(
|
||||
default=3,
|
||||
description="The maximum number of results to return",
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue