Merge branch 'main' into fix/embedding-model-type

This commit is contained in:
raghotham 2025-09-06 12:29:54 -07:00 committed by GitHub
commit 309f06829c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
59 changed files with 1005 additions and 339 deletions

View file

@ -527,7 +527,7 @@ class InferenceRouter(Inference):
# Store the response with the ID that will be returned to the client
if self.store:
await self.store.store_chat_completion(response, messages)
asyncio.create_task(self.store.store_chat_completion(response, messages))
if self.telemetry:
metrics = self._construct_metrics(
@ -855,4 +855,4 @@ class InferenceRouter(Inference):
object="chat.completion",
)
logger.debug(f"InferenceRouter.completion_response: {final_response}")
await self.store.store_chat_completion(final_response, messages)
asyncio.create_task(self.store.store_chat_completion(final_response, messages))

View file

@ -52,7 +52,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> VectorDB:
provider_vector_db_id = provider_vector_db_id or vector_db_id
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
@ -69,14 +68,33 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
provider = self.impls_by_provider_id[provider_id]
logger.warning(
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
)
vector_store = await provider.openai_create_vector_store(
name=vector_db_name or vector_db_id,
embedding_model=embedding_model,
embedding_dimension=model.metadata["embedding_dimension"],
provider_id=provider_id,
provider_vector_db_id=provider_vector_db_id,
)
vector_store_id = vector_store.id
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
logger.warning(
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
)
vector_db_data = {
"identifier": vector_db_id,
"identifier": vector_store_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": provider_vector_db_id,
"provider_resource_id": actual_provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_db_name,
"vector_db_name": vector_store.name,
}
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db)

View file

@ -132,15 +132,17 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
},
)
elif isinstance(exc, ConflictError):
return HTTPException(status_code=409, detail=str(exc))
return HTTPException(status_code=httpx.codes.CONFLICT, detail=str(exc))
elif isinstance(exc, ResourceNotFoundError):
return HTTPException(status_code=404, detail=str(exc))
return HTTPException(status_code=httpx.codes.NOT_FOUND, detail=str(exc))
elif isinstance(exc, ValueError):
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, BadRequestError):
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
elif isinstance(exc, PermissionError | AccessDeniedError):
return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
elif isinstance(exc, ConnectionError | httpx.ConnectError):
return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc))
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError):

View file

@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template()
name = "ci-tests"
template.name = name
template = get_starter_distribution_template(name="ci-tests")
template.description = "CI tests for Llama Stack"
return template

View file

@ -89,28 +89,28 @@ providers:
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/faiss_store.db
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec_registry.db
- provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus
config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ci-tests}/milvus.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/milvus_registry.db
- provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests/}/chroma_remote_registry.db
- provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector
config:
@ -121,15 +121,15 @@ providers:
password: ${env.PGVECTOR_PASSWORD:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ci-tests/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/files_metadata.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -89,28 +89,28 @@ providers:
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/faiss_store.db
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec_registry.db
- provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus
config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter-gpu}/milvus.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/milvus_registry.db
- provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu/}/chroma_remote_registry.db
- provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector
config:
@ -121,15 +121,15 @@ providers:
password: ${env.PGVECTOR_PASSWORD:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter-gpu/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/files_metadata.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template()
name = "starter-gpu"
template.name = name
template = get_starter_distribution_template(name="starter-gpu")
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
template.providers["post_training"] = [

View file

@ -99,9 +99,8 @@ def get_remote_inference_providers() -> list[Provider]:
return inference_providers
def get_distribution_template() -> DistributionTemplate:
def get_distribution_template(name: str = "starter") -> DistributionTemplate:
remote_inference_providers = get_remote_inference_providers()
name = "starter"
providers = {
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]

View file

@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions"]:
if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
)
if completion_window != "24h":
@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches):
)
valid = False
for param, expected_type, type_string in [
("model", str, "a string"),
# messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string?
]:
if batch.endpoint == "/v1/chat/completions":
required_params = [
("model", str, "a string"),
# messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string?
]
else: # /v1/completions
required_params = [
("model", str, "a string"),
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
]
for param, expected_type, type_string in required_params:
if param not in body:
errors.append(
BatchError(
@ -591,20 +599,37 @@ class ReferenceBatchesImpl(Batches):
try:
# TODO(SECURITY): review body for security issues
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
if request.url == "/v1/chat/completions":
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": chat_response.model_dump_json(),
},
}
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": chat_response.model_dump_json(),
},
}
else: # /v1/completions
completion_response = await self.inference_api.openai_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(completion_response, "model_dump_json"), (
"Completion response must have model_dump_json method"
)
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id,
"body": completion_response.model_dump_json(),
},
}
except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return {

View file

@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
from .memory import MemoryToolRuntimeImpl
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
await impl.initialize()
return impl

View file

@ -5,10 +5,15 @@
# the root directory of this source tree.
import asyncio
import base64
import io
import mimetypes
import secrets
import string
from typing import Any
import httpx
from fastapi import UploadFile
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
@ -17,6 +22,7 @@ from llama_stack.apis.common.content_types import (
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
ListToolDefsResponse,
@ -30,13 +36,18 @@ from llama_stack.apis.tools import (
ToolParameter,
ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
VectorStoreChunkingStrategyStatic,
VectorStoreChunkingStrategyStaticConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
make_overlapped_chunks,
parse_data_url,
)
from .config import RagToolRuntimeConfig
@ -55,10 +66,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
config: RagToolRuntimeConfig,
vector_io_api: VectorIO,
inference_api: Inference,
files_api: Files,
):
self.config = config
self.vector_io_api = vector_io_api
self.inference_api = inference_api
self.files_api = files_api
async def initialize(self):
pass
@ -78,27 +91,50 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
chunks = []
if not documents:
return
for doc in documents:
content = await content_from_doc(doc)
# TODO: we should add enrichment here as URLs won't be added to the metadata by default
chunks.extend(
make_overlapped_chunks(
doc.document_id,
content,
chunk_size_in_tokens,
chunk_size_in_tokens // 4,
doc.metadata,
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
parts = parse_data_url(doc.content.uri)
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
mime_type = parts["mimetype"]
else:
async with httpx.AsyncClient() as client:
response = await client.get(doc.content.uri)
file_data = response.content
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
else:
content_str = await content_from_doc(doc)
file_data = content_str.encode("utf-8")
mime_type = doc.mime_type or "text/plain"
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
file_obj = io.BytesIO(file_data)
file_obj.name = filename
upload_file = UploadFile(file=file_obj, filename=filename)
created_file = await self.files_api.openai_upload_file(
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig(
max_chunk_size_tokens=chunk_size_in_tokens,
chunk_overlap_tokens=chunk_size_in_tokens // 4,
)
)
if not chunks:
return
await self.vector_io_api.insert_chunks(
chunks=chunks,
vector_db_id=vector_db_id,
)
await self.vector_io_api.openai_attach_file_to_vector_store(
vector_store_id=vector_db_id,
file_id=created_file.id,
attributes=doc.metadata,
chunking_strategy=chunking_strategy,
)
async def query(
self,

View file

@ -116,7 +116,7 @@ def available_providers() -> list[ProviderSpec]:
adapter=AdapterSpec(
adapter_type="fireworks",
pip_packages=[
"fireworks-ai<=0.18.0",
"fireworks-ai<=0.17.16",
],
module="llama_stack.providers.remote.inference.fireworks",
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
@ -207,7 +207,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter=AdapterSpec(
adapter_type="gemini",
pip_packages=["litellm"],
pip_packages=["litellm", "openai"],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
@ -270,7 +270,7 @@ Available Models:
api=Api.inference,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=["litellm"],
pip_packages=["litellm", "openai"],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",

View file

@ -32,7 +32,7 @@ def available_providers() -> list[ProviderSpec]:
],
module="llama_stack.providers.inline.tool_runtime.rag",
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
api_dependencies=[Api.vector_io, Api.inference],
api_dependencies=[Api.vector_io, Api.inference, Api.files],
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
),
remote_provider_spec(

View file

@ -5,12 +5,13 @@
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import GeminiConfig
from .models import MODEL_ENTRIES
class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: GeminiConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
@ -21,6 +22,11 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
)
self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self):
return "https://generativelanguage.googleapis.com/v1beta/openai/"
async def initialize(self) -> None:
await super().initialize()

View file

@ -4,13 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import SambaNovaImplConfig
from .models import MODEL_ENTRIES
class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
SambaNova Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of LiteLLMOpenAIMixin.check_model_availability().
- OpenAIMixin.check_model_availability() queries the /v1/models to check if a model exists
- LiteLLMOpenAIMixin.check_model_availability() checks the static registry within LiteLLM
"""
def __init__(self, config: SambaNovaImplConfig):
self.config = config
self.environment_available_models = []
@ -24,3 +37,14 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
download_images=True, # SambaNova requires base64 image encoding
json_schema_strict=False, # SambaNova doesn't support strict=True yet
)
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Get the base URL for OpenAI mixin.
:return: The SambaNova base URL
"""
return self.config.url

View file

@ -4,53 +4,55 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from pydantic import BaseModel, Field
class BedrockBaseConfig(BaseModel):
aws_access_key_id: str | None = Field(
default=None,
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
)
aws_secret_access_key: str | None = Field(
default=None,
default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"),
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
)
aws_session_token: str | None = Field(
default=None,
default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"),
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
)
region_name: str | None = Field(
default=None,
default_factory=lambda: os.getenv("AWS_DEFAULT_REGION"),
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
"Default use environment variable: AWS_DEFAULT_REGION",
)
profile_name: str | None = Field(
default=None,
default_factory=lambda: os.getenv("AWS_PROFILE"),
description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE",
)
total_max_attempts: int | None = Field(
default=None,
default_factory=lambda: int(val) if (val := os.getenv("AWS_MAX_ATTEMPTS")) else None,
description="An integer representing the maximum number of attempts that will be made for a single request, "
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
)
retry_mode: str | None = Field(
default=None,
default_factory=lambda: os.getenv("AWS_RETRY_MODE"),
description="A string representing the type of retries Boto3 will perform."
"Default use environment variable: AWS_RETRY_MODE",
)
connect_timeout: float | None = Field(
default=60,
default_factory=lambda: float(os.getenv("AWS_CONNECT_TIMEOUT", "60")),
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
"The default is 60 seconds.",
)
read_timeout: float | None = Field(
default=60,
default_factory=lambda: float(os.getenv("AWS_READ_TIMEOUT", "60")),
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
"The default is 60 seconds.",
)
session_ttl: int | None = Field(
default=3600,
default_factory=lambda: int(os.getenv("AWS_SESSION_TTL", "3600")),
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
)

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import struct
from typing import TYPE_CHECKING
@ -43,9 +44,11 @@ class SentenceTransformerEmbeddingMixin:
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
embeddings = embedding_model.encode(
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False
embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
embeddings = await asyncio.to_thread(
embedding_model.encode,
[interleaved_content_as_str(content) for content in contents],
show_progress_bar=False,
)
return EmbeddingsResponse(embeddings=embeddings)
@ -64,8 +67,8 @@ class SentenceTransformerEmbeddingMixin:
# Get the model and generate embeddings
model_obj = await self.model_store.get_model(model)
embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
embeddings = embedding_model.encode(input_list, show_progress_bar=False)
embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
# Convert embeddings to the requested format
data = []
@ -93,7 +96,7 @@ class SentenceTransformerEmbeddingMixin:
usage=usage,
)
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
global EMBEDDING_MODELS
loaded_model = EMBEDDING_MODELS.get(model)
@ -101,8 +104,12 @@ class SentenceTransformerEmbeddingMixin:
return loaded_model
log.info(f"Loading sentence transformer for {model}...")
from sentence_transformers import SentenceTransformer
loaded_model = SentenceTransformer(model)
def _load_model():
from sentence_transformers import SentenceTransformer
return SentenceTransformer(model)
loaded_model = await asyncio.to_thread(_load_model)
EMBEDDING_MODELS[model] = loaded_model
return loaded_model

View file

@ -67,6 +67,38 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
raise AuthenticationRequiredError(exc) from exc
if i == len(connection_strategies) - 1:
raise
except* httpx.ConnectError as eg:
# Connection refused, server down, network unreachable
if i == len(connection_strategies) - 1:
error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused"
logger.error(f"MCP connection error: {error_msg}")
raise ConnectionError(error_msg) from eg
else:
logger.warning(
f"failed to connect to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
except* httpx.TimeoutException as eg:
# Request timeout, server too slow
if i == len(connection_strategies) - 1:
error_msg = f"MCP server at {endpoint} timed out"
logger.error(f"MCP timeout error: {error_msg}")
raise TimeoutError(error_msg) from eg
else:
logger.warning(
f"MCP server at {endpoint} timed out via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
except* httpx.RequestError as eg:
# DNS resolution failures, network errors, invalid URLs
if i == len(connection_strategies) - 1:
# Get the first exception's message for the error string
exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error"
error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}"
logger.error(f"MCP network error: {error_msg}")
raise ConnectionError(error_msg) from eg
else:
logger.warning(
f"network error connecting to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
except* McpError:
if i < len(connection_strategies) - 1:
logger.warning(