Merge branch 'main' into chroma

Commit d3958fae4f: 192 changed files with 7088 additions and 853 deletions

llama_stack/providers/remote/files/s3/README.md (new file, 237 lines)

@@ -0,0 +1,237 @@
# S3 Files Provider

A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence.

## Features

- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage
- **Metadata Management**: Uses a SQL database for efficient file metadata queries
- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints (see the usage sketch below)
- **Flexible Authentication**: Support for IAM roles and access keys
- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services
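The Files API is OpenAI-compatible, so once a stack is configured with this provider you can exercise it with any OpenAI client. Below is a minimal sketch; the port and the `/v1/openai/v1` base path are assumptions, so adjust them to wherever your server exposes its OpenAI-compatible routes:

```python
# Illustrative only: upload, list, and download a file through the stack.
from openai import OpenAI

# Assumed server address; the S3 provider itself is configured server-side.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# Upload: the provider stores the bytes in S3 and the metadata in the SQL store.
uploaded = client.files.create(file=open("report.pdf", "rb"), purpose="assistants")
print(uploaded.id, uploaded.bytes)

# List metadata and fetch the content back from S3.
print([f.id for f in client.files.list().data])
print(len(client.files.content(uploaded.id).read()), "bytes")
```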
## Configuration

### Basic Configuration

```yaml
api: files
provider_type: remote::s3
config:
  bucket_name: my-llama-stack-files
  region: us-east-1
  metadata_store:
    type: sqlite
    db_path: ./s3_files_metadata.db
```

### Advanced Configuration

```yaml
api: files
provider_type: remote::s3
config:
  bucket_name: my-llama-stack-files
  region: us-east-1
  aws_access_key_id: YOUR_ACCESS_KEY
  aws_secret_access_key: YOUR_SECRET_KEY
  endpoint_url: https://s3.amazonaws.com  # Optional for custom endpoints
  metadata_store:
    type: sqlite
    db_path: ./s3_files_metadata.db
```

### Environment Variables

The configuration supports environment variable substitution:

```yaml
config:
  bucket_name: "${env.S3_BUCKET_NAME}"
  region: "${env.AWS_REGION:=us-east-1}"
  aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}"
  aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}"
  endpoint_url: "${env.S3_ENDPOINT_URL:=}"
```

Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique.

## Authentication

### IAM Roles (Recommended)

For production deployments, use IAM roles:

```yaml
config:
  bucket_name: my-bucket
  region: us-east-1
  # No credentials needed - will use IAM role
```

### Access Keys

For development or specific use cases:

```yaml
config:
  bucket_name: my-bucket
  region: us-east-1
  aws_access_key_id: AKIAIOSFODNN7EXAMPLE
  aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
```

## S3 Bucket Setup

### Required Permissions

The S3 provider requires the following permissions:

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "s3:GetObject",
        "s3:PutObject",
        "s3:DeleteObject",
        "s3:ListBucket"
      ],
      "Resource": [
        "arn:aws:s3:::your-bucket-name",
        "arn:aws:s3:::your-bucket-name/*"
      ]
    }
  ]
}
```

### Automatic Bucket Creation

By default, the S3 provider expects the bucket to already exist. If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration:

```yaml
config:
  bucket_name: my-bucket
  auto_create_bucket: true  # Will create bucket if it doesn't exist
  region: us-east-1
```

**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions:

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "s3:GetObject",
        "s3:PutObject",
        "s3:DeleteObject",
        "s3:ListBucket",
        "s3:CreateBucket"
      ],
      "Resource": [
        "arn:aws:s3:::your-bucket-name",
        "arn:aws:s3:::your-bucket-name/*"
      ]
    }
  ]
}
```

### Bucket Policy (Optional)

For additional security, you can add a bucket policy:

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "LlamaStackAccess",
      "Effect": "Allow",
      "Principal": {
        "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
      },
      "Action": [
        "s3:GetObject",
        "s3:PutObject",
        "s3:DeleteObject"
      ],
      "Resource": "arn:aws:s3:::your-bucket-name/*"
    },
    {
      "Sid": "LlamaStackBucketAccess",
      "Effect": "Allow",
      "Principal": {
        "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
      },
      "Action": [
        "s3:ListBucket"
      ],
      "Resource": "arn:aws:s3:::your-bucket-name"
    }
  ]
}
```

## Features

### Metadata Persistence

File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes the following fields (a quick inspection sketch follows the list):

- File ID
- Original filename
- Purpose (assistants, batch, etc.)
- File size in bytes
- Created and expiration timestamps
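For a quick look at what actually gets persisted, you can open the SQLite metadata store directly. This is purely illustrative and assumes the default `./s3_files_metadata.db` path used in the examples above; the provider keeps these rows in an `openai_files` table:

```python
# Read-only peek at the file metadata rows (illustrative; not part of the provider API).
import sqlite3

conn = sqlite3.connect("./s3_files_metadata.db")
conn.row_factory = sqlite3.Row

query = (
    "SELECT id, filename, purpose, bytes, created_at, expires_at "
    "FROM openai_files ORDER BY created_at DESC LIMIT 5"
)
for row in conn.execute(query):
    print(dict(row))

conn.close()
```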
### TTL and Cleanup

Files currently have a fixed long expiration time (100 years).

## Development and Testing

### Using MinIO

For self-hosted S3-compatible storage:

```yaml
config:
  bucket_name: test-bucket
  region: us-east-1
  endpoint_url: http://localhost:9000
  aws_access_key_id: minioadmin
  aws_secret_access_key: minioadmin
```
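If the test bucket does not exist yet, you can create it with boto3 before starting the stack. This sketch assumes MinIO is already running locally with the default `minioadmin` credentials shown above; alternatively, set `auto_create_bucket: true` and let the provider create it on startup:

```python
# One-time local setup; the values mirror the MinIO config above.
import boto3

s3 = boto3.client(
    "s3",
    endpoint_url="http://localhost:9000",
    aws_access_key_id="minioadmin",
    aws_secret_access_key="minioadmin",
    region_name="us-east-1",
)
s3.create_bucket(Bucket="test-bucket")
print([b["Name"] for b in s3.list_buckets()["Buckets"]])
```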
## Monitoring and Logging

The provider logs important operations and errors. For production deployments, consider:

- CloudWatch monitoring for S3 operations
- Custom metrics for file upload/download rates
- Error rate monitoring
- Performance metrics tracking

## Error Handling

The provider handles various error scenarios (an example of how these surface to clients follows the list):

- S3 connectivity issues
- Bucket access permissions
- File not found errors
- Metadata consistency checks
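For example, requesting a file ID that does not exist is rejected rather than returning partial data. The sketch below mirrors the earlier usage example; the base URL is an assumption, and the exact HTTP status depends on how the server maps the provider's not-found error:

```python
# Illustrative: a missing file ID surfaces as an API error to the caller.
import openai

client = openai.OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

try:
    client.files.retrieve("file-does-not-exist")
except openai.APIStatusError as err:
    print("lookup failed with HTTP status", err.status_code)
```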
## Known Limitations

- Fixed long TTL (100 years) instead of configurable expiration
- No server-side encryption enabled by default
- No support for AWS session tokens
- No S3 key prefix organization support
- No multipart upload support (all files uploaded as single objects)

llama_stack/providers/remote/files/s3/__init__.py (new file, 20 lines)

@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.core.datatypes import Api

from .config import S3FilesImplConfig


async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
    from .files import S3FilesImpl

    # TODO: authorization policies and user separation
    impl = S3FilesImpl(config)
    await impl.initialize()
    return impl

llama_stack/providers/remote/files/s3/config.py (new file, 42 lines)

@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig


class S3FilesImplConfig(BaseModel):
    """Configuration for S3-based files provider."""

    bucket_name: str = Field(description="S3 bucket name to store files")
    region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
    aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
    aws_secret_access_key: str | None = Field(
        default=None, description="AWS secret access key (optional if using IAM roles)"
    )
    endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
    auto_create_bucket: bool = Field(
        default=False, description="Automatically create the S3 bucket if it doesn't exist"
    )
    metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")

    @classmethod
    def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
        return {
            "bucket_name": "${env.S3_BUCKET_NAME}",  # no default, buckets must be globally unique
            "region": "${env.AWS_REGION:=us-east-1}",
            "aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}",
            "aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
            "endpoint_url": "${env.S3_ENDPOINT_URL:=}",
            "auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
            "metadata_store": SqliteSqlStoreConfig.sample_run_config(
                __distro_dir__=__distro_dir__,
                db_name="s3_files_metadata.db",
            ),
        }

llama_stack/providers/remote/files/s3/files.py (new file, 272 lines)

@@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
import uuid
from typing import Annotated

import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import File, Form, Response, UploadFile

from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
    Files,
    ListOpenAIFileResponse,
    OpenAIFileDeleteResponse,
    OpenAIFileObject,
    OpenAIFilePurpose,
)
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl

from .config import S3FilesImplConfig

# TODO: provider data for S3 credentials


def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
    try:
        s3_config = {
            "region_name": config.region,
        }

        # endpoint URL if specified (for MinIO, LocalStack, etc.)
        if config.endpoint_url:
            s3_config["endpoint_url"] = config.endpoint_url

        if config.aws_access_key_id and config.aws_secret_access_key:
            s3_config.update(
                {
                    "aws_access_key_id": config.aws_access_key_id,
                    "aws_secret_access_key": config.aws_secret_access_key,
                }
            )

        return boto3.client("s3", **s3_config)

    except (BotoCoreError, NoCredentialsError) as e:
        raise RuntimeError(f"Failed to initialize S3 client: {e}") from e


async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None:
    try:
        client.head_bucket(Bucket=config.bucket_name)
    except ClientError as e:
        error_code = e.response["Error"]["Code"]
        if error_code == "404":
            if not config.auto_create_bucket:
                raise RuntimeError(
                    f"S3 bucket '{config.bucket_name}' does not exist. "
                    f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration."
                ) from e
            try:
                # For us-east-1, we can't specify LocationConstraint
                if config.region == "us-east-1":
                    client.create_bucket(Bucket=config.bucket_name)
                else:
                    client.create_bucket(
                        Bucket=config.bucket_name,
                        CreateBucketConfiguration={"LocationConstraint": config.region},
                    )
            except ClientError as create_error:
                raise RuntimeError(
                    f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
                ) from create_error
        elif error_code == "403":
            raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
        else:
            raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e


class S3FilesImpl(Files):
    """S3-based implementation of the Files API."""

    # TODO: implement expiration, for now a silly offset
    _SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60

    def __init__(self, config: S3FilesImplConfig) -> None:
        self._config = config
        self._client: boto3.client | None = None
        self._sql_store: SqlStore | None = None

    async def initialize(self) -> None:
        self._client = _create_s3_client(self._config)
        await _create_bucket_if_not_exists(self._client, self._config)

        self._sql_store = sqlstore_impl(self._config.metadata_store)
        await self._sql_store.create_table(
            "openai_files",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "filename": ColumnType.STRING,
                "purpose": ColumnType.STRING,
                "bytes": ColumnType.INTEGER,
                "created_at": ColumnType.INTEGER,
                "expires_at": ColumnType.INTEGER,
                # TODO: add s3_etag field for integrity checking
            },
        )

    async def shutdown(self) -> None:
        pass

    @property
    def client(self) -> boto3.client:
        assert self._client is not None, "Provider not initialized"
        return self._client

    @property
    def sql_store(self) -> SqlStore:
        assert self._sql_store is not None, "Provider not initialized"
        return self._sql_store

    async def openai_upload_file(
        self,
        file: Annotated[UploadFile, File()],
        purpose: Annotated[OpenAIFilePurpose, Form()],
    ) -> OpenAIFileObject:
        file_id = f"file-{uuid.uuid4().hex}"

        filename = getattr(file, "filename", None) or "uploaded_file"

        created_at = int(time.time())
        expires_at = created_at + self._SILLY_EXPIRATION_OFFSET
        content = await file.read()
        file_size = len(content)

        await self.sql_store.insert(
            "openai_files",
            {
                "id": file_id,
                "filename": filename,
                "purpose": purpose.value,
                "bytes": file_size,
                "created_at": created_at,
                "expires_at": expires_at,
            },
        )

        try:
            self.client.put_object(
                Bucket=self._config.bucket_name,
                Key=file_id,
                Body=content,
                # TODO: enable server-side encryption
            )
        except ClientError as e:
            await self.sql_store.delete("openai_files", where={"id": file_id})

            raise RuntimeError(f"Failed to upload file to S3: {e}") from e

        return OpenAIFileObject(
            id=file_id,
            filename=filename,
            purpose=purpose,
            bytes=file_size,
            created_at=created_at,
            expires_at=expires_at,
        )

    async def openai_list_files(
        self,
        after: str | None = None,
        limit: int | None = 10000,
        order: Order | None = Order.desc,
        purpose: OpenAIFilePurpose | None = None,
    ) -> ListOpenAIFileResponse:
        # this is purely defensive; it should not happen because the router also defaults to Order.desc.
        if not order:
            order = Order.desc

        where_conditions = {}
        if purpose:
            where_conditions["purpose"] = purpose.value

        paginated_result = await self.sql_store.fetch_all(
            table="openai_files",
            where=where_conditions if where_conditions else None,
            order_by=[("created_at", order.value)],
            cursor=("id", after) if after else None,
            limit=limit,
        )

        files = [
            OpenAIFileObject(
                id=row["id"],
                filename=row["filename"],
                purpose=OpenAIFilePurpose(row["purpose"]),
                bytes=row["bytes"],
                created_at=row["created_at"],
                expires_at=row["expires_at"],
            )
            for row in paginated_result.data
        ]

        return ListOpenAIFileResponse(
            data=files,
            has_more=paginated_result.has_more,
            # empty string or None? spec says str, ref impl returns str | None, we go with spec
            first_id=files[0].id if files else "",
            last_id=files[-1].id if files else "",
        )

    async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
        if not row:
            raise ResourceNotFoundError(file_id, "File", "files.list()")

        return OpenAIFileObject(
            id=row["id"],
            filename=row["filename"],
            purpose=OpenAIFilePurpose(row["purpose"]),
            bytes=row["bytes"],
            created_at=row["created_at"],
            expires_at=row["expires_at"],
        )

    async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
        if not row:
            raise ResourceNotFoundError(file_id, "File", "files.list()")

        try:
            self.client.delete_object(
                Bucket=self._config.bucket_name,
                Key=row["id"],
            )
        except ClientError as e:
            if e.response["Error"]["Code"] != "NoSuchKey":
                raise RuntimeError(f"Failed to delete file from S3: {e}") from e

        await self.sql_store.delete("openai_files", where={"id": file_id})

        return OpenAIFileDeleteResponse(id=file_id, deleted=True)

    async def openai_retrieve_file_content(self, file_id: str) -> Response:
        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
        if not row:
            raise ResourceNotFoundError(file_id, "File", "files.list()")

        try:
            response = self.client.get_object(
                Bucket=self._config.bucket_name,
                Key=row["id"],
            )
            # TODO: can we stream this instead of loading it into memory
            content = response["Body"].read()
        except ClientError as e:
            if e.response["Error"]["Code"] == "NoSuchKey":
                await self.sql_store.delete("openai_files", where={"id": file_id})
                raise ResourceNotFoundError(file_id, "File", "files.list()") from e
            raise RuntimeError(f"Failed to download file from S3: {e}") from e

        return Response(
            content=content,
            media_type="application/octet-stream",
            headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
        )

@@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES

logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::fireworks")


class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

@@ -3,15 +3,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging

from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

from .models import MODEL_ENTRIES

logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference::llama_openai_compat")


class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):

@@ -41,6 +41,11 @@ client.initialize()

### Create Completion

> Note on Completion API
>
> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with `NVIDIA_BASE_URL="https://integrate.api.nvidia.com"` do not support the `completion` method, while the locally deployed NIM does.

```python
response = client.inference.completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
@ -76,7 +81,78 @@ response = client.inference.chat_completion(
|
|||
print(f"Response: {response.completion_message.content}")
|
||||
```
|
||||
|
||||
### Tool Calling Example ###
|
||||
```python
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
|
||||
tool_definition = ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get current weather information for a location",
|
||||
parameters={
|
||||
"location": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The city and state, e.g. San Francisco, CA",
|
||||
required=True,
|
||||
),
|
||||
"unit": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="Temperature unit (celsius or fahrenheit)",
|
||||
required=False,
|
||||
default="celsius",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
tool_response = client.inference.chat_completion(
|
||||
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
|
||||
tools=[tool_definition],
|
||||
)
|
||||
|
||||
print(f"Tool Response: {tool_response.completion_message.content}")
|
||||
if tool_response.completion_message.tool_calls:
|
||||
for tool_call in tool_response.completion_message.tool_calls:
|
||||
print(f"Tool Called: {tool_call.tool_name}")
|
||||
print(f"Arguments: {tool_call.arguments}")
|
||||
```
|
||||
|
||||
### Structured Output Example
|
||||
```python
|
||||
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
|
||||
|
||||
person_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"occupation": {"type": "string"},
|
||||
},
|
||||
"required": ["name", "age", "occupation"],
|
||||
}
|
||||
|
||||
response_format = JsonSchemaResponseFormat(
|
||||
type=ResponseFormatType.json_schema, json_schema=person_schema
|
||||
)
|
||||
|
||||
structured_response = client.inference.chat_completion(
|
||||
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
|
||||
}
|
||||
],
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
print(f"Structured Response: {structured_response.completion_message.content}")
|
||||
```
|
||||
|
||||
### Create Embeddings
|
||||
> Note on OpenAI embeddings compatibility
|
||||
>
|
||||
> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.
|
||||
|
||||
```python
|
||||
response = client.inference.embeddings(
|
||||
model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||
|

@@ -4,11 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import warnings
from collections.abc import AsyncIterator

from openai import APIConnectionError, BadRequestError
from openai import NOT_GIVEN, APIConnectionError

from llama_stack.apis.common.content_types import (
    InterleavedContent,

@@ -27,12 +26,16 @@ from llama_stack.apis.inference import (
    Inference,
    LogProbConfig,
    Message,
    OpenAIEmbeddingData,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,

@@ -54,7 +57,7 @@ from .openai_utils import (
)
from .utils import _is_nvidia_hosted

logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference::nvidia")


class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):

@@ -194,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
        }
        extra_body["input_type"] = task_type_options[task_type]

        try:
            response = await self.client.embeddings.create(
                model=provider_model_id,
                input=input,
                extra_body=extra_body,
            )
        except BadRequestError as e:
            raise ValueError(f"Failed to get embeddings: {e}") from e

        response = await self.client.embeddings.create(
            model=provider_model_id,
            input=input,
            extra_body=extra_body,
        )
        #
        # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
        # ->

@@ -210,6 +209,57 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
        #
        return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])

    async def openai_embeddings(
        self,
        model: str,
        input: str | list[str],
        encoding_format: str | None = "float",
        dimensions: int | None = None,
        user: str | None = None,
    ) -> OpenAIEmbeddingsResponse:
        """
        OpenAI-compatible embeddings for NVIDIA NIM.

        Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API.
        We default this to "query" to ensure requests succeed when using the
        OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with
        `task_type='document'`.
        """
        extra_body: dict[str, object] = {"input_type": "query"}
        logger.warning(
            "NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. "
            "For passage embeddings, use the embeddings API with task_type='document'."
        )

        response = await self.client.embeddings.create(
            model=await self._get_provider_model_id(model),
            input=input,
            encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
            dimensions=dimensions if dimensions is not None else NOT_GIVEN,
            user=user if user is not None else NOT_GIVEN,
            extra_body=extra_body,
        )

        data = []
        for i, embedding_data in enumerate(response.data):
            data.append(
                OpenAIEmbeddingData(
                    embedding=embedding_data.embedding,
                    index=i,
                )
            )

        usage = OpenAIEmbeddingUsage(
            prompt_tokens=response.usage.prompt_tokens,
            total_tokens=response.usage.total_tokens,
        )

        return OpenAIEmbeddingsResponse(
            data=data,
            model=response.model,
            usage=usage,
        )

    async def chat_completion(
        self,
        model_id: str,

@@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging

import httpx

from llama_stack.log import get_logger

from . import NVIDIAConfig

logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference::nvidia")


def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:

@@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

from .models import MODEL_ENTRIES

logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::ollama")


class OllamaInferenceAdapter(

@@ -619,28 +619,6 @@ class OllamaInferenceAdapter(
        response.id = id
        return response

    async def batch_completion(
        self,
        model_id: str,
        content_batch: list[InterleavedContent],
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        logprobs: LogProbConfig | None = None,
    ):
        raise NotImplementedError("Batch completion is not supported for Ollama")

    async def batch_chat_completion(
        self,
        model_id: str,
        messages_batch: list[list[Message]],
        sampling_params: SamplingParams | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_config: ToolConfig | None = None,
        response_format: ResponseFormat | None = None,
        logprobs: LogProbConfig | None = None,
    ):
        raise NotImplementedError("Batch chat completion is not supported for Ollama")


async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
    async def _convert_content(content) -> dict:

@@ -4,15 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging

from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

from .config import OpenAIConfig
from .models import MODEL_ENTRIES

logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference::openai")


#

@@ -5,7 +5,6 @@
# the root directory of this source tree.


import logging
from collections.abc import AsyncGenerator

from huggingface_hub import AsyncInferenceClient, HfApi

@@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
    ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (

@@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig

log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference::tgi")


def build_hf_repo_model_entries():

@@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES

logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::together")


class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

@@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

from .config import VLLMInferenceAdapterConfig

log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference::vllm")


def build_hf_repo_model_entries():

@@ -711,25 +711,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
            user=user,
        )
        return await self.client.chat.completions.create(**params)  # type: ignore

    async def batch_completion(
        self,
        model_id: str,
        content_batch: list[InterleavedContent],
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        logprobs: LogProbConfig | None = None,
    ):
        raise NotImplementedError("Batch completion is not supported for Ollama")

    async def batch_chat_completion(
        self,
        model_id: str,
        messages_batch: list[list[Message]],
        sampling_params: SamplingParams | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_config: ToolConfig | None = None,
        response_format: ResponseFormat | None = None,
        logprobs: LogProbConfig | None = None,
    ):
        raise NotImplementedError("Batch chat completion is not supported for Ollama")

@@ -4,18 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import warnings
from typing import Any

from pydantic import BaseModel

from llama_stack.apis.post_training import TrainingConfig
from llama_stack.log import get_logger
from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig

from .config import NvidiaPostTrainingConfig

logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="post_training::nvidia")


def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:

@@ -5,7 +5,6 @@
# the root directory of this source tree.

import json
import logging
from typing import Any

from llama_stack.apis.inference import Message

@@ -16,12 +15,13 @@ from llama_stack.apis.safety import (
    ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.bedrock.client import create_bedrock_client

from .config import BedrockSafetyConfig

logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="safety::bedrock")


class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):

@@ -4,20 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
from typing import Any

import requests

from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new

from .config import NVIDIASafetyConfig

logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="safety::nvidia")


class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):

@@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
        self.shield = NeMoGuardrails(self.config, shield.shield_id)
        return await self.shield.run(messages)

    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
        raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")


class NeMoGuardrails:
    """

@@ -5,7 +5,6 @@
# the root directory of this source tree.

import json
import logging
from typing import Any

import litellm

@@ -20,12 +19,13 @@ from llama_stack.apis.safety import (
)
from llama_stack.apis.shields import Shield
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new

from .config import SambaNovaSafetyConfig

logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="safety::sambanova")

CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

@@ -6,7 +6,6 @@
import asyncio
import heapq
import json
import logging
from typing import Any
from urllib.parse import urlparse

@@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
    QueryChunksResponse,
    VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl

@@ -35,7 +35,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import Reranker

from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig

log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="vector_io::chroma")

ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI

@@ -5,7 +5,6 @@
# the root directory of this source tree.

import asyncio
import logging
import os
from typing import Any

@@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
    QueryChunksResponse,
    VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl

@@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti

from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig

logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="vector_io::milvus")

VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
from typing import Any

import psycopg2

@@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import (
    QueryChunksResponse,
    VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore

@@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import (

from .config import PGVectorVectorIOConfig

log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="vector_io::pgvector")

VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"

@@ -5,7 +5,6 @@
# the root directory of this source tree.

import asyncio
import logging
import uuid
from typing import Any

@@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
    VectorStoreChunkingStrategy,
    VectorStoreFileObject,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl

@@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (

from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig

log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="vector_io::qdrant")
CHUNK_ID_KEY = "_chunk_id"

# KV store prefixes for vector databases

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import Any

import weaviate

@@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore

@@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti

from .config import WeaviateVectorIOConfig

log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="vector_io::weaviate")

VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"