Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-04 10:10:36 +00:00

Merge branch 'main' into add-mcp-authentication-param

Commit 114ab693a5: 40 changed files with 2827 additions and 1700 deletions
@@ -74,7 +74,7 @@ class Benchmarks(Protocol):
         """
         ...

-    @webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
+    @webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA, deprecated=True)
     async def register_benchmark(
         self,
         benchmark_id: str,
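The hunks in this file and the API files below all make the same change: resource registration and unregistration routes gain a `deprecated=True` flag on their `@webmethod` decorator. As a rough sketch of the mechanism, assuming a simplified stand-in rather than llama-stack's actual `webmethod` implementation, such a flag is typically just route metadata that the server and OpenAPI generator read back later:

```python
# Hedged sketch only: a minimal stand-in for a webmethod-style decorator.
# The real llama-stack webmethod takes more parameters; this just shows how a
# deprecated flag can be attached as metadata and inspected afterwards.
def webmethod(route: str, method: str = "GET", level: str = "v1", deprecated: bool = False):
    def wrap(fn):
        fn.__webmethod__ = {"route": route, "method": method, "level": level, "deprecated": deprecated}
        return fn
    return wrap


@webmethod(route="/eval/benchmarks", method="POST", level="v1alpha", deprecated=True)
async def register_benchmark(benchmark_id: str) -> None: ...


assert register_benchmark.__webmethod__["deprecated"] is True
```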
@@ -95,7 +95,7 @@ class Benchmarks(Protocol):
         """
         ...

-    @webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
+    @webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA, deprecated=True)
     async def unregister_benchmark(self, benchmark_id: str) -> None:
         """Unregister a benchmark.

@@ -146,7 +146,7 @@ class ListDatasetsResponse(BaseModel):


 class Datasets(Protocol):
-    @webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
+    @webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA, deprecated=True)
     async def register_dataset(
         self,
         purpose: DatasetPurpose,
@@ -235,7 +235,7 @@ class Datasets(Protocol):
         """
         ...

-    @webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
+    @webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA, deprecated=True)
     async def unregister_dataset(
         self,
         dataset_id: str,
@@ -136,7 +136,7 @@ class Models(Protocol):
         """
         ...

-    @webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     async def register_model(
         self,
         model_id: str,
@@ -158,7 +158,7 @@ class Models(Protocol):
         """
         ...

-    @webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
     async def unregister_model(
         self,
         model_id: str,
@@ -178,7 +178,7 @@ class ScoringFunctions(Protocol):
         """
         ...

-    @webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     async def register_scoring_function(
         self,
         scoring_fn_id: str,
@@ -199,7 +199,9 @@ class ScoringFunctions(Protocol):
         """
         ...

-    @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
+    @webmethod(
+        route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
+    )
     async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
         """Unregister a scoring function.

@@ -67,7 +67,7 @@ class Shields(Protocol):
         """
         ...

-    @webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     async def register_shield(
         self,
         shield_id: str,
@@ -85,7 +85,7 @@ class Shields(Protocol):
         """
         ...

-    @webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
     async def unregister_shield(self, identifier: str) -> None:
         """Unregister a shield.

@@ -109,7 +109,7 @@ class ListToolDefsResponse(BaseModel):
 @runtime_checkable
 @telemetry_traceable
 class ToolGroups(Protocol):
-    @webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     async def register_tool_group(
         self,
         toolgroup_id: str,
@@ -167,7 +167,7 @@ class ToolGroups(Protocol):
         """
         ...

-    @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
+    @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
     async def unregister_toolgroup(
         self,
         toolgroup_id: str,
@@ -396,19 +396,19 @@ class VectorStoreListFilesResponse(BaseModel):


 @json_schema_type
-class VectorStoreFileContentsResponse(BaseModel):
-    """Response from retrieving the contents of a vector store file.
+class VectorStoreFileContentResponse(BaseModel):
+    """Represents the parsed content of a vector store file.

-    :param file_id: Unique identifier for the file
-    :param filename: Name of the file
-    :param attributes: Key-value attributes associated with the file
-    :param content: List of content items from the file
+    :param object: The object type, which is always `vector_store.file_content.page`
+    :param data: Parsed content of the file
+    :param has_more: Indicates if there are more content pages to fetch
+    :param next_page: The token for the next page, if any
     """

-    file_id: str
-    filename: str
-    attributes: dict[str, Any]
-    content: list[VectorStoreContent]
+    object: Literal["vector_store.file_content.page"] = "vector_store.file_content.page"
+    data: list[VectorStoreContent]
+    has_more: bool
+    next_page: str | None = None


 @json_schema_type
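The import hunks further down show that the renamed model is exported from `llama_stack.apis.vector_io`. A minimal construction of the new page-shaped response, shown here as an illustrative sketch rather than something taken from the PR's tests:

```python
from llama_stack.apis.vector_io import VectorStoreFileContentResponse

# `data` carries the parsed VectorStoreContent items; the mixin change at the
# bottom of this diff currently always returns a single page.
page = VectorStoreFileContentResponse(
    data=[],
    has_more=False,
    next_page=None,
)
assert page.object == "vector_store.file_content.page"
```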
@@ -732,12 +732,12 @@ class VectorIO(Protocol):
         self,
         vector_store_id: str,
         file_id: str,
-    ) -> VectorStoreFileContentsResponse:
+    ) -> VectorStoreFileContentResponse:
         """Retrieves the contents of a vector store file.

         :param vector_store_id: The ID of the vector store containing the file to retrieve.
         :param file_id: The ID of the file to retrieve.
-        :returns: A list of InterleavedContent representing the file contents.
+        :returns: A VectorStoreFileContentResponse representing the file contents.
         """
         ...

@@ -24,7 +24,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreChunkingStrategyStaticConfig,
     VectorStoreDeleteResponse,
     VectorStoreFileBatchObject,
-    VectorStoreFileContentsResponse,
+    VectorStoreFileContentResponse,
     VectorStoreFileDeleteResponse,
     VectorStoreFileObject,
     VectorStoreFilesListInBatchResponse,
@@ -338,7 +338,7 @@ class VectorIORouter(VectorIO):
         self,
         vector_store_id: str,
         file_id: str,
-    ) -> VectorStoreFileContentsResponse:
+    ) -> VectorStoreFileContentResponse:
         logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
         provider = await self.routing_table.get_provider_impl(vector_store_id)
         return await provider.openai_retrieve_vector_store_file_contents(
@@ -15,7 +15,7 @@ from llama_stack.apis.vector_io.vector_io import (
     SearchRankingOptions,
     VectorStoreChunkingStrategy,
     VectorStoreDeleteResponse,
-    VectorStoreFileContentsResponse,
+    VectorStoreFileContentResponse,
     VectorStoreFileDeleteResponse,
     VectorStoreFileObject,
     VectorStoreFileStatus,
@@ -195,7 +195,7 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
         self,
         vector_store_id: str,
         file_id: str,
-    ) -> VectorStoreFileContentsResponse:
+    ) -> VectorStoreFileContentResponse:
         await self.assert_action_allowed("read", "vector_store", vector_store_id)
         provider = await self.get_provider_impl(vector_store_id)
         return await provider.openai_retrieve_vector_store_file_contents(
src/llama_stack/distributions/oci/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .oci import get_distribution_template  # noqa: F401
src/llama_stack/distributions/oci/build.yaml (new file, 35 lines)
@@ -0,0 +1,35 @@
+version: 2
+distribution_spec:
+  description: Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM
+    inference with scalable cloud services
+  providers:
+    inference:
+    - provider_type: remote::oci
+    vector_io:
+    - provider_type: inline::faiss
+    - provider_type: remote::chromadb
+    - provider_type: remote::pgvector
+    safety:
+    - provider_type: inline::llama-guard
+    agents:
+    - provider_type: inline::meta-reference
+    eval:
+    - provider_type: inline::meta-reference
+    datasetio:
+    - provider_type: remote::huggingface
+    - provider_type: inline::localfs
+    scoring:
+    - provider_type: inline::basic
+    - provider_type: inline::llm-as-judge
+    - provider_type: inline::braintrust
+    tool_runtime:
+    - provider_type: remote::brave-search
+    - provider_type: remote::tavily-search
+    - provider_type: inline::rag-runtime
+    - provider_type: remote::model-context-protocol
+    files:
+    - provider_type: inline::localfs
+image_type: venv
+additional_pip_packages:
+- aiosqlite
+- sqlalchemy[asyncio]
src/llama_stack/distributions/oci/doc_template.md (new file, 140 lines)
@@ -0,0 +1,140 @@
+---
+orphan: true
+---
+# OCI Distribution
+
+The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
+
+{{ providers_table }}
+
+{% if run_config_env_vars %}
+### Environment Variables
+
+The following environment variables can be configured:
+
+{% for var, (default_value, description) in run_config_env_vars.items() %}
+- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
+{% endfor %}
+{% endif %}
+
+{% if default_models %}
+### Models
+
+The following models are available by default:
+
+{% for model in default_models %}
+- `{{ model.model_id }} {{ model.doc_string }}`
+{% endfor %}
+{% endif %}
+
+## Prerequisites
+### Oracle Cloud Infrastructure Setup
+
+Before using the OCI Generative AI distribution, ensure you have:
+
+1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/)
+2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy
+3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models
+4. **Authentication**: Configure authentication using either:
+   - **Instance Principal** (recommended for cloud-hosted deployments)
+   - **API Key** (for on-premises or development environments)
+
+### Authentication Methods
+
+#### Instance Principal Authentication (Recommended)
+Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments.
+
+Requirements:
+- Instance must be running in an Oracle Cloud Infrastructure compartment
+- Instance must have appropriate IAM policies to access Generative AI services
+
+#### API Key Authentication
+For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file.
+
+### Required IAM Policies
+
+Ensure your OCI user or instance has the following policy statements:
+
+```
+Allow group <group_name> to use generative-ai-inference-endpoints in compartment <compartment_name>
+Allow group <group_name> to manage generative-ai-inference-endpoints in compartment <compartment_name>
+```
+
+## Supported Services
+
+### Inference: OCI Generative AI
+Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports:
+
+- **Chat Completions**: Conversational AI with context awareness
+- **Text Generation**: Complete prompts and generate text content
+
+#### Available Models
+OCI Generative AI offers models from Meta, Cohere, OpenAI, Grok, and other providers.
+
+### Safety: Llama Guard
+For content safety and moderation, this distribution uses Meta's Llama Guard model through the OCI Generative AI service to provide:
+- Content filtering and moderation
+- Policy compliance checking
+- Harmful content detection
+
+### Vector Storage: Multiple Options
+The distribution supports several vector storage providers:
+- **FAISS**: Local in-memory vector search
+- **ChromaDB**: Distributed vector database
+- **PGVector**: PostgreSQL with vector extensions
+
+### Additional Services
+- **Dataset I/O**: Local filesystem and Hugging Face integration
+- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities
+- **Evaluation**: Meta reference evaluation framework
+
+## Running Llama Stack with OCI
+
+You can run the OCI distribution via Docker or a local virtual environment.
+
+### Via venv
+
+If you've set up your local development environment, you can run the distribution from your local virtual environment.
+
+```bash
+OCI_AUTH=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci
+```
+
+### Configuration Examples
+
+#### Using Instance Principal (Recommended for Production)
+```bash
+export OCI_AUTH_TYPE=instance_principal
+export OCI_REGION=us-chicago-1
+export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..<your-compartment-id>
+```
+
+#### Using API Key Authentication (Development)
+```bash
+export OCI_AUTH_TYPE=config_file
+export OCI_CONFIG_FILE_PATH=~/.oci/config
+export OCI_CLI_PROFILE=DEFAULT
+export OCI_REGION=us-chicago-1
+export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id
+```
+
+## Regional Endpoints
+
+OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit:
+
+https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions
+
+## Troubleshooting
+
+### Common Issues
+
+1. **Authentication Errors**: Verify your OCI credentials and IAM policies
+2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region
+3. **Permission Denied**: Check compartment permissions and Generative AI service access
+4. **Region Unavailable**: Verify the specified region supports Generative AI services
+
+### Getting Help
+
+For additional support:
+- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)
+- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues)
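As a quick smoke test once `llama stack run --port 8321 oci` is up, the V1 routes shown earlier in this diff can be queried directly. A minimal sketch; the host and port are the run.yaml defaults, the `/v1` prefix is an assumption about how the server mounts V1-level routes, and the response shape is not reproduced here:

```python
# Minimal smoke test against a locally running stack; assumes the default
# server config from run.yaml above (port 8321) and the /models route from
# the Models protocol shown earlier in this diff.
import httpx

resp = httpx.get("http://localhost:8321/v1/models")
resp.raise_for_status()
print(resp.json())
```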
src/llama_stack/distributions/oci/oci.py (new file, 108 lines)
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+
+from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
+from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
+from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
+from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
+from llama_stack.providers.remote.inference.oci.config import OCIConfig
+
+
+def get_distribution_template(name: str = "oci") -> DistributionTemplate:
+    providers = {
+        "inference": [BuildProvider(provider_type="remote::oci")],
+        "vector_io": [
+            BuildProvider(provider_type="inline::faiss"),
+            BuildProvider(provider_type="remote::chromadb"),
+            BuildProvider(provider_type="remote::pgvector"),
+        ],
+        "safety": [BuildProvider(provider_type="inline::llama-guard")],
+        "agents": [BuildProvider(provider_type="inline::meta-reference")],
+        "eval": [BuildProvider(provider_type="inline::meta-reference")],
+        "datasetio": [
+            BuildProvider(provider_type="remote::huggingface"),
+            BuildProvider(provider_type="inline::localfs"),
+        ],
+        "scoring": [
+            BuildProvider(provider_type="inline::basic"),
+            BuildProvider(provider_type="inline::llm-as-judge"),
+            BuildProvider(provider_type="inline::braintrust"),
+        ],
+        "tool_runtime": [
+            BuildProvider(provider_type="remote::brave-search"),
+            BuildProvider(provider_type="remote::tavily-search"),
+            BuildProvider(provider_type="inline::rag-runtime"),
+            BuildProvider(provider_type="remote::model-context-protocol"),
+        ],
+        "files": [BuildProvider(provider_type="inline::localfs")],
+    }
+
+    inference_provider = Provider(
+        provider_id="oci",
+        provider_type="remote::oci",
+        config=OCIConfig.sample_run_config(),
+    )
+
+    vector_io_provider = Provider(
+        provider_id="faiss",
+        provider_type="inline::faiss",
+        config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+    )
+
+    files_provider = Provider(
+        provider_id="meta-reference-files",
+        provider_type="inline::localfs",
+        config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+    )
+    default_tool_groups = [
+        ToolGroupInput(
+            toolgroup_id="builtin::websearch",
+            provider_id="tavily-search",
+        ),
+    ]
+
+    return DistributionTemplate(
+        name=name,
+        distro_type="remote_hosted",
+        description="Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM inference with scalable cloud services",
+        container_image=None,
+        template_path=Path(__file__).parent / "doc_template.md",
+        providers=providers,
+        run_configs={
+            "run.yaml": RunConfigSettings(
+                provider_overrides={
+                    "inference": [inference_provider],
+                    "vector_io": [vector_io_provider],
+                    "files": [files_provider],
+                },
+                default_tool_groups=default_tool_groups,
+            ),
+        },
+        run_config_env_vars={
+            "OCI_AUTH_TYPE": (
+                "instance_principal",
+                "OCI authentication type (instance_principal or config_file)",
+            ),
+            "OCI_REGION": (
+                "",
+                "OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1)",
+            ),
+            "OCI_COMPARTMENT_OCID": (
+                "",
+                "OCI compartment ID for the Generative AI service",
+            ),
+            "OCI_CONFIG_FILE_PATH": (
+                "~/.oci/config",
+                "OCI config file path (required if OCI_AUTH_TYPE is config_file)",
+            ),
+            "OCI_CLI_PROFILE": (
+                "DEFAULT",
+                "OCI CLI profile name to use from config file",
+            ),
+        },
+    )
src/llama_stack/distributions/oci/run.yaml (new file, 136 lines)
@@ -0,0 +1,136 @@
+version: 2
+image_name: oci
+apis:
+- agents
+- datasetio
+- eval
+- files
+- inference
+- safety
+- scoring
+- tool_runtime
+- vector_io
+providers:
+  inference:
+  - provider_id: oci
+    provider_type: remote::oci
+    config:
+      oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal}
+      oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}
+      oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT}
+      oci_region: ${env.OCI_REGION:=us-ashburn-1}
+      oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=}
+  vector_io:
+  - provider_id: faiss
+    provider_type: inline::faiss
+    config:
+      persistence:
+        namespace: vector_io::faiss
+        backend: kv_default
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config:
+      excluded_categories: []
+  agents:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      persistence:
+        agent_state:
+          namespace: agents
+          backend: kv_default
+        responses:
+          table_name: responses
+          backend: sql_default
+          max_write_queue_size: 10000
+          num_writers: 4
+  eval:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      kvstore:
+        namespace: eval
+        backend: kv_default
+  datasetio:
+  - provider_id: huggingface
+    provider_type: remote::huggingface
+    config:
+      kvstore:
+        namespace: datasetio::huggingface
+        backend: kv_default
+  - provider_id: localfs
+    provider_type: inline::localfs
+    config:
+      kvstore:
+        namespace: datasetio::localfs
+        backend: kv_default
+  scoring:
+  - provider_id: basic
+    provider_type: inline::basic
+  - provider_id: llm-as-judge
+    provider_type: inline::llm-as-judge
+  - provider_id: braintrust
+    provider_type: inline::braintrust
+    config:
+      openai_api_key: ${env.OPENAI_API_KEY:=}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:=}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:=}
+      max_results: 3
+  - provider_id: rag-runtime
+    provider_type: inline::rag-runtime
+  - provider_id: model-context-protocol
+    provider_type: remote::model-context-protocol
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/oci/files}
+      metadata_store:
+        table_name: files_metadata
+        backend: sql_default
+storage:
+  backends:
+    kv_default:
+      type: kv_sqlite
+      db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/kvstore.db
+    sql_default:
+      type: sql_sqlite
+      db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/sql_store.db
+  stores:
+    metadata:
+      namespace: registry
+      backend: kv_default
+    inference:
+      table_name: inference_store
+      backend: sql_default
+      max_write_queue_size: 10000
+      num_writers: 4
+    conversations:
+      table_name: openai_conversations
+      backend: sql_default
+    prompts:
+      namespace: prompts
+      backend: kv_default
+registered_resources:
+  models: []
+  shields: []
+  vector_dbs: []
+  datasets: []
+  scoring_fns: []
+  benchmarks: []
+  tool_groups:
+  - toolgroup_id: builtin::websearch
+    provider_id: tavily-search
+server:
+  port: 8321
+telemetry:
+  enabled: true
@@ -223,7 +223,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
         return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")

     async def register_vector_store(self, vector_store: VectorStore) -> None:
-        assert self.kvstore is not None
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before registering vector stores.")

         key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
         await self.kvstore.set(key=key, value=vector_store.model_dump_json())
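A recurring change in this and the other vector-IO adapters below replaces `assert self.kvstore is not None` with an explicit `None` check that raises `RuntimeError`. One likely motivation, stated here as editorial context rather than taken from the PR description: `assert` statements are compiled out when CPython runs with `-O`, so the old guard silently disappears in optimized deployments, while an explicit raise always fires:

```python
# Self-contained illustration of the difference; run with and without `python -O`.
def guard_with_assert(kvstore):
    assert kvstore is not None  # compiled away under python -O
    return kvstore


def guard_with_raise(kvstore):
    if kvstore is None:  # always enforced, regardless of interpreter flags
        raise RuntimeError("KVStore not initialized. Call initialize() first.")
    return kvstore


guard_with_raise(object())  # fine
# guard_with_raise(None)    # raises RuntimeError in every mode
```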
@@ -239,7 +240,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
         return [i.vector_store for i in self.cache.values()]

     async def unregister_vector_store(self, vector_store_id: str) -> None:
-        assert self.kvstore is not None
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before unregistering vector stores.")

         if vector_store_id not in self.cache:
             return
@@ -248,6 +250,27 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoco
         del self.cache[vector_store_id]
         await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")

+    async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
+        if vector_store_id in self.cache:
+            return self.cache[vector_store_id]
+
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
+
+        key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
+        vector_store_data = await self.kvstore.get(key)
+        if not vector_store_data:
+            raise VectorStoreNotFoundError(vector_store_id)
+
+        vector_store = VectorStore.model_validate_json(vector_store_data)
+        index = VectorStoreWithIndex(
+            vector_store=vector_store,
+            index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
+            inference_api=self.inference_api,
+        )
+        self.cache[vector_store_id] = index
+        return index
+
     async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         index = self.cache.get(vector_store_id)
         if index is None:
@@ -412,6 +412,14 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
         return [v.vector_store for v in self.cache.values()]

+    async def register_vector_store(self, vector_store: VectorStore) -> None:
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before registering vector stores.")
+
+        # Save to kvstore for persistence
+        key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
+        await self.kvstore.set(key=key, value=vector_store.model_dump_json())
+
         # Create and cache the index
         index = await SQLiteVecIndex.create(
             vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
         )
@@ -421,13 +429,16 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
         if vector_store_id in self.cache:
             return self.cache[vector_store_id]

         if self.vector_store_table is None:
             raise VectorStoreNotFoundError(vector_store_id)

         vector_store = self.vector_store_table.get_vector_store(vector_store_id)
         if not vector_store:
-            raise VectorStoreNotFoundError(vector_store_id)
+            # Try to load from kvstore
+            if self.kvstore is None:
+                raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
+
+            key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
+            vector_store_data = await self.kvstore.get(key)
+            if not vector_store_data:
+                raise VectorStoreNotFoundError(vector_store_id)
+
+            vector_store = VectorStore.model_validate_json(vector_store_data)
         index = VectorStoreWithIndex(
             vector_store=vector_store,
             index=SQLiteVecIndex(
@@ -297,6 +297,20 @@ Available Models:
 Azure OpenAI inference provider for accessing GPT models and other Azure services.
 Provider documentation
 https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
 """,
     ),
+    RemoteProviderSpec(
+        api=Api.inference,
+        provider_type="remote::oci",
+        adapter_type="oci",
+        pip_packages=["oci"],
+        module="llama_stack.providers.remote.inference.oci",
+        config_class="llama_stack.providers.remote.inference.oci.config.OCIConfig",
+        provider_data_validator="llama_stack.providers.remote.inference.oci.config.OCIProviderDataValidator",
+        description="""
+Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models.
+Provider documentation
+https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm
+""",
+    ),
 ]
src/llama_stack/providers/remote/inference/oci/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.inference import InferenceProvider
+
+from .config import OCIConfig
+
+
+async def get_adapter_impl(config: OCIConfig, _deps) -> InferenceProvider:
+    from .oci import OCIInferenceAdapter
+
+    adapter = OCIInferenceAdapter(config=config)
+    await adapter.initialize()
+    return adapter
src/llama_stack/providers/remote/inference/oci/auth.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from collections.abc import Generator, Mapping
+from typing import Any, override
+
+import httpx
+import oci
+import requests
+from oci.config import DEFAULT_LOCATION, DEFAULT_PROFILE
+
+OciAuthSigner = type[oci.signer.AbstractBaseSigner]
+
+
+class HttpxOciAuth(httpx.Auth):
+    """
+    Custom HTTPX authentication class that implements OCI request signing.
+
+    This class handles the authentication flow for HTTPX requests by signing them
+    using the OCI Signer, which adds the necessary authentication headers for
+    OCI API calls.
+
+    Attributes:
+        signer (oci.signer.Signer): The OCI signer instance used for request signing
+    """
+
+    def __init__(self, signer: OciAuthSigner):
+        self.signer = signer
+
+    @override
+    def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
+        # Read the request content to handle streaming requests properly
+        try:
+            content = request.content
+        except httpx.RequestNotRead:
+            # For streaming requests, we need to read the content first
+            content = request.read()
+
+        req = requests.Request(
+            method=request.method,
+            url=str(request.url),
+            headers=dict(request.headers),
+            data=content,
+        )
+        prepared_request = req.prepare()
+
+        # Sign the request using the OCI Signer
+        self.signer.do_request_sign(prepared_request)  # type: ignore
+
+        # Update the original HTTPX request with the signed headers
+        request.headers.update(prepared_request.headers)
+
+        yield request
+
+
+class OciInstancePrincipalAuth(HttpxOciAuth):
+    def __init__(self, **kwargs: Mapping[str, Any]):
+        self.signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner(**kwargs)
+
+
+class OciUserPrincipalAuth(HttpxOciAuth):
+    def __init__(self, config_file: str = DEFAULT_LOCATION, profile_name: str = DEFAULT_PROFILE):
+        config = oci.config.from_file(config_file, profile_name)
+        oci.config.validate_config(config)  # type: ignore
+        key_content = ""
+        with open(config["key_file"]) as f:
+            key_content = f.read()
+
+        self.signer = oci.signer.Signer(
+            tenancy=config["tenancy"],
+            user=config["user"],
+            fingerprint=config["fingerprint"],
+            private_key_file_location=config.get("key_file"),
+            pass_phrase="none",  # type: ignore
+            private_key_content=key_content,
+        )
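Standalone use of these auth classes would look roughly like the sketch below; in the PR itself the auth object is only ever handed to the OpenAI client's `http_client` (see `get_extra_client_params` in oci.py further down). The URL is a placeholder, not an endpoint taken from this diff, and constructing `OciUserPrincipalAuth` requires a real `~/.oci/config` on disk:

```python
# Illustrative only: sign arbitrary httpx requests with an OCI user principal.
import httpx

from llama_stack.providers.remote.inference.oci.auth import OciUserPrincipalAuth

auth = OciUserPrincipalAuth(config_file="~/.oci/config", profile_name="DEFAULT")
with httpx.Client(auth=auth) as client:
    # auth_flow() signs each outgoing request, adding OCI signature headers.
    resp = client.get("https://example.oci.oraclecloud.com/placeholder")  # placeholder URL
```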
src/llama_stack/providers/remote/inference/oci/config.py (new file, 75 lines)
@@ -0,0 +1,75 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
+from llama_stack.schema_utils import json_schema_type
+
+
+class OCIProviderDataValidator(BaseModel):
+    oci_auth_type: str = Field(
+        description="OCI authentication type (must be one of: instance_principal, config_file)",
+    )
+    oci_region: str = Field(
+        description="OCI region (e.g., us-ashburn-1)",
+    )
+    oci_compartment_id: str = Field(
+        description="OCI compartment ID for the Generative AI service",
+    )
+    oci_config_file_path: str | None = Field(
+        default="~/.oci/config",
+        description="OCI config file path (required if oci_auth_type is config_file)",
+    )
+    oci_config_profile: str | None = Field(
+        default="DEFAULT",
+        description="OCI config profile (required if oci_auth_type is config_file)",
+    )
+
+
+@json_schema_type
+class OCIConfig(RemoteInferenceProviderConfig):
+    oci_auth_type: str = Field(
+        description="OCI authentication type (must be one of: instance_principal, config_file)",
+        default_factory=lambda: os.getenv("OCI_AUTH_TYPE", "instance_principal"),
+    )
+    oci_region: str = Field(
+        default_factory=lambda: os.getenv("OCI_REGION", "us-ashburn-1"),
+        description="OCI region (e.g., us-ashburn-1)",
+    )
+    oci_compartment_id: str = Field(
+        default_factory=lambda: os.getenv("OCI_COMPARTMENT_OCID", ""),
+        description="OCI compartment ID for the Generative AI service",
+    )
+    oci_config_file_path: str = Field(
+        default_factory=lambda: os.getenv("OCI_CONFIG_FILE_PATH", "~/.oci/config"),
+        description="OCI config file path (required if oci_auth_type is config_file)",
+    )
+    oci_config_profile: str = Field(
+        default_factory=lambda: os.getenv("OCI_CLI_PROFILE", "DEFAULT"),
+        description="OCI config profile (required if oci_auth_type is config_file)",
+    )
+
+    @classmethod
+    def sample_run_config(
+        cls,
+        oci_auth_type: str = "${env.OCI_AUTH_TYPE:=instance_principal}",
+        oci_config_file_path: str = "${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}",
+        oci_config_profile: str = "${env.OCI_CLI_PROFILE:=DEFAULT}",
+        oci_region: str = "${env.OCI_REGION:=us-ashburn-1}",
+        oci_compartment_id: str = "${env.OCI_COMPARTMENT_OCID:=}",
+        **kwargs,
+    ) -> dict[str, Any]:
+        return {
+            "oci_auth_type": oci_auth_type,
+            "oci_config_file_path": oci_config_file_path,
+            "oci_config_profile": oci_config_profile,
+            "oci_region": oci_region,
+            "oci_compartment_id": oci_compartment_id,
+        }
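Because every field has an env-backed `default_factory`, the config can be built with no explicit arguments when the `OCI_*` environment variables are set, and keyword arguments override them. A small sketch, assuming `RemoteInferenceProviderConfig` adds no required fields of its own and using a hypothetical compartment OCID:

```python
import os

from llama_stack.providers.remote.inference.oci.config import OCIConfig

# Values come from OCI_* env vars when arguments are omitted.
os.environ["OCI_REGION"] = "us-chicago-1"
cfg = OCIConfig(oci_compartment_id="ocid1.compartment.oc1..example")  # hypothetical OCID
assert cfg.oci_region == "us-chicago-1"
assert cfg.oci_auth_type == "instance_principal"  # the default_factory fallback
```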
src/llama_stack/providers/remote/inference/oci/oci.py (new file, 140 lines)
@@ -0,0 +1,140 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+from collections.abc import Iterable
+from typing import Any
+
+import httpx
+import oci
+from oci.generative_ai.generative_ai_client import GenerativeAiClient
+from oci.generative_ai.models import ModelCollection
+from openai._base_client import DefaultAsyncHttpxClient
+
+from llama_stack.apis.inference.inference import (
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+)
+from llama_stack.apis.models import ModelType
+from llama_stack.log import get_logger
+from llama_stack.providers.remote.inference.oci.auth import OciInstancePrincipalAuth, OciUserPrincipalAuth
+from llama_stack.providers.remote.inference.oci.config import OCIConfig
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
+
+logger = get_logger(name=__name__, category="inference::oci")
+
+OCI_AUTH_TYPE_INSTANCE_PRINCIPAL = "instance_principal"
+OCI_AUTH_TYPE_CONFIG_FILE = "config_file"
+VALID_OCI_AUTH_TYPES = [OCI_AUTH_TYPE_INSTANCE_PRINCIPAL, OCI_AUTH_TYPE_CONFIG_FILE]
+DEFAULT_OCI_REGION = "us-ashburn-1"
+
+MODEL_CAPABILITIES = ["TEXT_GENERATION", "TEXT_SUMMARIZATION", "TEXT_EMBEDDINGS", "CHAT"]
+
+
+class OCIInferenceAdapter(OpenAIMixin):
+    config: OCIConfig
+
+    async def initialize(self) -> None:
+        """Initialize and validate OCI configuration."""
+        if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES:
+            raise ValueError(
+                f"Invalid OCI authentication type: {self.config.oci_auth_type}. "
+                f"Valid types are one of: {VALID_OCI_AUTH_TYPES}"
+            )
+
+        if not self.config.oci_compartment_id:
+            raise ValueError("OCI_COMPARTMENT_OCID is a required parameter. Either set in env variable or config.")
+
+    def get_base_url(self) -> str:
+        region = self.config.oci_region or DEFAULT_OCI_REGION
+        return f"https://inference.generativeai.{region}.oci.oraclecloud.com/20231130/actions/v1"
+
+    def get_api_key(self) -> str | None:
+        # OCI doesn't use API keys, it uses request signing
+        return "<NOTUSED>"
+
+    def get_extra_client_params(self) -> dict[str, Any]:
+        """
+        Get extra parameters for the AsyncOpenAI client, including OCI-specific auth and headers.
+        """
+        auth = self._get_auth()
+        compartment_id = self.config.oci_compartment_id or ""
+
+        return {
+            "http_client": DefaultAsyncHttpxClient(
+                auth=auth,
+                headers={
+                    "CompartmentId": compartment_id,
+                },
+            ),
+        }
+
+    def _get_oci_signer(self) -> oci.signer.AbstractBaseSigner | None:
+        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
+            return oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
+        return None
+
+    def _get_oci_config(self) -> dict:
+        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
+            config = {"region": self.config.oci_region}
+        elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
+            config = oci.config.from_file(self.config.oci_config_file_path, self.config.oci_config_profile)
+            if not config.get("region"):
+                raise ValueError(
+                    "Region not specified in config. Please specify in config or with OCI_REGION env variable."
+                )
+
+        return config
+
+    def _get_auth(self) -> httpx.Auth:
+        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
+            return OciInstancePrincipalAuth()
+        elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
+            return OciUserPrincipalAuth(
+                config_file=self.config.oci_config_file_path, profile_name=self.config.oci_config_profile
+            )
+        else:
+            raise ValueError(f"Invalid OCI authentication type: {self.config.oci_auth_type}")
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        List available models from OCI Generative AI service.
+        """
+        oci_config = self._get_oci_config()
+        oci_signer = self._get_oci_signer()
+        compartment_id = self.config.oci_compartment_id or ""
+
+        if oci_signer is None:
+            client = GenerativeAiClient(config=oci_config)
+        else:
+            client = GenerativeAiClient(config=oci_config, signer=oci_signer)
+
+        models: ModelCollection = client.list_models(
+            compartment_id=compartment_id, capability=MODEL_CAPABILITIES, lifecycle_state="ACTIVE"
+        ).data
+
+        seen_models = set()
+        model_ids = []
+        for model in models.items:
+            if model.time_deprecated or model.time_on_demand_retired:
+                continue
+
+            if "CHAT" not in model.capabilities or "FINE_TUNE" in model.capabilities:
+                continue
+
+            # Use display_name + model_type as the key to avoid conflicts
+            model_key = (model.display_name, ModelType.llm)
+            if model_key in seen_models:
+                continue
+
+            seen_models.add(model_key)
+            model_ids.append(model.display_name)
+
+        return model_ids
+
+    async def openai_embeddings(self, params: OpenAIEmbeddingsRequestWithExtraBody) -> OpenAIEmbeddingsResponse:
+        # The constructed url is a mask that hits OCI's "chat" action, which is not supported for embeddings.
+        raise NotImplementedError("OCI Provider does not (currently) support embeddings")
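The adapter routes all chat traffic through a region-scoped endpoint built in `get_base_url()`. Derived directly from the f-string above and shown as an editorial illustration (the compartment OCID is hypothetical):

```python
from llama_stack.providers.remote.inference.oci.config import OCIConfig
from llama_stack.providers.remote.inference.oci.oci import OCIInferenceAdapter

cfg = OCIConfig(
    oci_auth_type="instance_principal",
    oci_region="eu-frankfurt-1",
    oci_compartment_id="ocid1.compartment.oc1..example",  # hypothetical OCID
)
adapter = OCIInferenceAdapter(config=cfg)
print(adapter.get_base_url())
# -> https://inference.generativeai.eu-frankfurt-1.oci.oraclecloud.com/20231130/actions/v1
```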
@@ -131,7 +131,6 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc

     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.persistence)
-        self.vector_store_table = self.kvstore

         if isinstance(self.config, RemoteChromaVectorIOConfig):
             log.info(f"Connecting to Chroma server at: {self.config.url}")
@@ -190,9 +189,16 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         if vector_store_id in self.cache:
             return self.cache[vector_store_id]

-        vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
-        if not vector_store:
-            raise ValueError(f"Vector DB {vector_store_id} not found in Llama Stack")
+        # Try to load from kvstore
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
+
+        key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
+        vector_store_data = await self.kvstore.get(key)
+        if not vector_store_data:
+            raise ValueError(f"Vector DB {vector_store_id} not found in Llama Stack")
+
+        vector_store = VectorStore.model_validate_json(vector_store_data)
         collection = await maybe_await(self.client.get_collection(vector_store_id))
         if not collection:
             raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
@@ -328,13 +328,16 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         if vector_store_id in self.cache:
             return self.cache[vector_store_id]

         if self.vector_store_table is None:
             raise VectorStoreNotFoundError(vector_store_id)

         vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
         if not vector_store:
-            raise VectorStoreNotFoundError(vector_store_id)
+            # Try to load from kvstore
+            if self.kvstore is None:
+                raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
+
+            key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
+            vector_store_data = await self.kvstore.get(key)
+            if not vector_store_data:
+                raise VectorStoreNotFoundError(vector_store_id)
+
+            vector_store = VectorStore.model_validate_json(vector_store_data)
         index = VectorStoreWithIndex(
             vector_store=vector_store,
             index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
@@ -368,6 +368,22 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
             log.exception("Could not connect to PGVector database server")
             raise RuntimeError("Could not connect to PGVector database server") from e

+        # Load existing vector stores from KV store into cache
+        start_key = VECTOR_DBS_PREFIX
+        end_key = f"{VECTOR_DBS_PREFIX}\xff"
+        stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
+        for vector_store_data in stored_vector_stores:
+            vector_store = VectorStore.model_validate_json(vector_store_data)
+            pgvector_index = PGVectorIndex(
+                vector_store=vector_store,
+                dimension=vector_store.embedding_dimension,
+                conn=self.conn,
+                kvstore=self.kvstore,
+            )
+            await pgvector_index.initialize()
+            index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
+            self.cache[vector_store.identifier] = index
+
     async def shutdown(self) -> None:
         if self.conn is not None:
             self.conn.close()
@@ -377,7 +393,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt

     async def register_vector_store(self, vector_store: VectorStore) -> None:
         # Persist vector DB metadata in the KV store
-        assert self.kvstore is not None
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before registering vector stores.")
+
+        # Save to kvstore for persistence
         key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
         await self.kvstore.set(key=key, value=vector_store.model_dump_json())

         # Upsert model metadata in Postgres
         upsert_models(self.conn, [(vector_store.identifier, vector_store)])

@@ -396,7 +418,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
             del self.cache[vector_store_id]

         # Delete vector DB metadata from KV store
-        assert self.kvstore is not None
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before unregistering vector stores.")
         await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_store_id}")

     async def insert_chunks(self, vector_store_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
@@ -413,13 +436,16 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProt
         if vector_store_id in self.cache:
             return self.cache[vector_store_id]

         if self.vector_store_table is None:
             raise VectorStoreNotFoundError(vector_store_id)

         vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
         if not vector_store:
-            raise VectorStoreNotFoundError(vector_store_id)
+            # Try to load from kvstore
+            if self.kvstore is None:
+                raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
+
+            key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
+            vector_store_data = await self.kvstore.get(key)
+            if not vector_store_data:
+                raise VectorStoreNotFoundError(vector_store_id)
+
+            vector_store = VectorStore.model_validate_json(vector_store_data)
         index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
         await index.initialize()
         self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
@@ -183,7 +183,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
         await super().shutdown()

     async def register_vector_store(self, vector_store: VectorStore) -> None:
-        assert self.kvstore is not None
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before registering vector stores.")
         key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
         await self.kvstore.set(key=key, value=vector_store.model_dump_json())

@@ -200,20 +201,24 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             await self.cache[vector_store_id].index.delete()
             del self.cache[vector_store_id]

-        assert self.kvstore is not None
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
         await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")

     async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
         if vector_store_id in self.cache:
             return self.cache[vector_store_id]

-        if self.vector_store_table is None:
-            raise ValueError(f"Vector DB not found {vector_store_id}")
+        # Try to load from kvstore
+        if self.kvstore is None:
+            raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")

         vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
         if not vector_store:
+            key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
+            vector_store_data = await self.kvstore.get(key)
+            if not vector_store_data:
+                raise VectorStoreNotFoundError(vector_store_id)
+
+            vector_store = VectorStore.model_validate_json(vector_store_data)
         index = VectorStoreWithIndex(
             vector_store=vector_store,
             index=QdrantIndex(client=self.client, collection_name=vector_store.identifier),
@@ -346,13 +346,16 @@ class WeaviateVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, NeedsRequestProv
         if vector_store_id in self.cache:
             return self.cache[vector_store_id]

         if self.vector_store_table is None:
             raise VectorStoreNotFoundError(vector_store_id)

         vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
         if not vector_store:
-            raise VectorStoreNotFoundError(vector_store_id)
+            # Try to load from kvstore
+            if self.kvstore is None:
+                raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")
+
+            key = f"{VECTOR_DBS_PREFIX}{vector_store_id}"
+            vector_store_data = await self.kvstore.get(key)
+            if not vector_store_data:
+                raise VectorStoreNotFoundError(vector_store_id)
+
+            vector_store = VectorStore.model_validate_json(vector_store_data)
         client = self._get_client()
         sanitized_collection_name = sanitize_collection_name(vector_store.identifier, weaviate_format=True)
         if not client.collections.exists(sanitized_collection_name):
@@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreContent,
     VectorStoreDeleteResponse,
    VectorStoreFileBatchObject,
-    VectorStoreFileContentsResponse,
+    VectorStoreFileContentResponse,
     VectorStoreFileCounts,
     VectorStoreFileDeleteResponse,
     VectorStoreFileLastError,
@@ -921,22 +921,21 @@ class OpenAIVectorStoreMixin(ABC):
         self,
         vector_store_id: str,
         file_id: str,
-    ) -> VectorStoreFileContentsResponse:
+    ) -> VectorStoreFileContentResponse:
         """Retrieves the contents of a vector store file."""
         if vector_store_id not in self.openai_vector_stores:
             raise VectorStoreNotFoundError(vector_store_id)

-        file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
         dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
         chunks = [Chunk.model_validate(c) for c in dict_chunks]
         content = []
         for chunk in chunks:
             content.extend(self._chunk_to_vector_store_content(chunk))
-        return VectorStoreFileContentsResponse(
-            file_id=file_id,
-            filename=file_info.get("filename", ""),
-            attributes=file_info.get("attributes", {}),
-            content=content,
+        return VectorStoreFileContentResponse(
+            object="vector_store.file_content.page",
+            data=content,
+            has_more=False,
+            next_page=None,
         )

     async def openai_update_vector_store_file(
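Callers migrating from the old response should read `data` instead of `content` and be prepared to loop on `has_more`/`next_page`, even though this implementation always returns a single page today. A hedged consumer sketch; the `fetch_page` callable is hypothetical and written against the new page shape rather than any API in this PR:

```python
from llama_stack.apis.vector_io import VectorStoreFileContentResponse


def collect_all_content(fetch_page) -> list:
    # fetch_page(token) -> VectorStoreFileContentResponse; hypothetical callable.
    items, token = [], None
    while True:
        page: VectorStoreFileContentResponse = fetch_page(token)
        items.extend(page.data)
        if not page.has_more:
            return items
        token = page.next_page
```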