Add a special header per-client call to parser provider data

This commit is contained in:
Ashwin Bharambe 2024-09-18 09:17:59 -07:00 committed by Xi Yan
parent a6be32bc3d
commit 32beecb20d
11 changed files with 955 additions and 104 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import TogetherImplConfig
from .config import TogetherImplConfig, TogetherHeaderExtractor
async def get_adapter_impl(config: TogetherImplConfig, _deps):

View file

@ -4,9 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_models.schema_utils import json_schema_type
from llama_stack.distribution.request_headers import annotate_header
class TogetherHeaderExtractor(BaseModel):
api_key: annotate_header(
"X-LlamaStack-Together-ApiKey", str, "The API Key for the request"
)
@json_schema_type
class TogetherImplConfig(BaseModel):

View file

@ -63,6 +63,7 @@ def available_providers() -> List[ProviderSpec]:
],
module="llama_stack.providers.adapters.inference.together",
config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
),
),
]

View file

@ -16,6 +16,7 @@ import httpx
import numpy as np
from numpy.typing import NDArray
from pypdf import PdfReader
from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
@ -160,6 +161,8 @@ class BankWithIndex:
self.bank.config.overlap_size_in_tokens
or (self.bank.config.chunk_size_in_tokens // 4),
)
if not chunks:
continue
embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
await self.index.add_chunks(chunks, embeddings)