mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-08 21:04:39 +00:00
Merge branch 'main' into fix/embedding-model-type
This commit is contained in:
commit
309f06829c
59 changed files with 1005 additions and 339 deletions
|
@ -527,7 +527,7 @@ class InferenceRouter(Inference):
|
|||
|
||||
# Store the response with the ID that will be returned to the client
|
||||
if self.store:
|
||||
await self.store.store_chat_completion(response, messages)
|
||||
asyncio.create_task(self.store.store_chat_completion(response, messages))
|
||||
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
|
@ -855,4 +855,4 @@ class InferenceRouter(Inference):
|
|||
object="chat.completion",
|
||||
)
|
||||
logger.debug(f"InferenceRouter.completion_response: {final_response}")
|
||||
await self.store.store_chat_completion(final_response, messages)
|
||||
asyncio.create_task(self.store.store_chat_completion(final_response, messages))
|
||||
|
|
|
@ -52,7 +52,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
provider_vector_db_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
) -> VectorDB:
|
||||
provider_vector_db_id = provider_vector_db_id or vector_db_id
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id) > 0:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
@ -69,14 +68,33 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||
|
||||
provider = self.impls_by_provider_id[provider_id]
|
||||
logger.warning(
|
||||
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
|
||||
)
|
||||
vector_store = await provider.openai_create_vector_store(
|
||||
name=vector_db_name or vector_db_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=model.metadata["embedding_dimension"],
|
||||
provider_id=provider_id,
|
||||
provider_vector_db_id=provider_vector_db_id,
|
||||
)
|
||||
|
||||
vector_store_id = vector_store.id
|
||||
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
|
||||
logger.warning(
|
||||
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
|
||||
)
|
||||
|
||||
vector_db_data = {
|
||||
"identifier": vector_db_id,
|
||||
"identifier": vector_store_id,
|
||||
"type": ResourceType.vector_db.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": provider_vector_db_id,
|
||||
"provider_resource_id": actual_provider_vector_db_id,
|
||||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
"vector_db_name": vector_db_name,
|
||||
"vector_db_name": vector_store.name,
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
|
|
|
@ -132,15 +132,17 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
},
|
||||
)
|
||||
elif isinstance(exc, ConflictError):
|
||||
return HTTPException(status_code=409, detail=str(exc))
|
||||
return HTTPException(status_code=httpx.codes.CONFLICT, detail=str(exc))
|
||||
elif isinstance(exc, ResourceNotFoundError):
|
||||
return HTTPException(status_code=404, detail=str(exc))
|
||||
return HTTPException(status_code=httpx.codes.NOT_FOUND, detail=str(exc))
|
||||
elif isinstance(exc, ValueError):
|
||||
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
|
||||
elif isinstance(exc, BadRequestError):
|
||||
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
|
||||
elif isinstance(exc, PermissionError | AccessDeniedError):
|
||||
return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
|
||||
elif isinstance(exc, ConnectionError | httpx.ConnectError):
|
||||
return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc))
|
||||
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
|
||||
return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
|
||||
elif isinstance(exc, NotImplementedError):
|
||||
|
|
|
@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
|
|||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
template = get_starter_distribution_template()
|
||||
name = "ci-tests"
|
||||
template.name = name
|
||||
template = get_starter_distribution_template(name="ci-tests")
|
||||
template.description = "CI tests for Llama Stack"
|
||||
|
||||
return template
|
||||
|
|
|
@ -89,28 +89,28 @@ providers:
|
|||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/faiss_store.db
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec_registry.db
|
||||
- provider_id: ${env.MILVUS_URL:+milvus}
|
||||
provider_type: inline::milvus
|
||||
config:
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ci-tests}/milvus.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/milvus_registry.db
|
||||
- provider_id: ${env.CHROMADB_URL:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests/}/chroma_remote_registry.db
|
||||
- provider_id: ${env.PGVECTOR_DB:+pgvector}
|
||||
provider_type: remote::pgvector
|
||||
config:
|
||||
|
@ -121,15 +121,15 @@ providers:
|
|||
password: ${env.PGVECTOR_PASSWORD:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ci-tests/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/files_metadata.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
|
|
|
@ -89,28 +89,28 @@ providers:
|
|||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/faiss_store.db
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec_registry.db
|
||||
- provider_id: ${env.MILVUS_URL:+milvus}
|
||||
provider_type: inline::milvus
|
||||
config:
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter-gpu}/milvus.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/milvus_registry.db
|
||||
- provider_id: ${env.CHROMADB_URL:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu/}/chroma_remote_registry.db
|
||||
- provider_id: ${env.PGVECTOR_DB:+pgvector}
|
||||
provider_type: remote::pgvector
|
||||
config:
|
||||
|
@ -121,15 +121,15 @@ providers:
|
|||
password: ${env.PGVECTOR_PASSWORD:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter-gpu/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/files_metadata.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
|
|
|
@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
|
|||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
template = get_starter_distribution_template()
|
||||
name = "starter-gpu"
|
||||
template.name = name
|
||||
template = get_starter_distribution_template(name="starter-gpu")
|
||||
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
|
||||
|
||||
template.providers["post_training"] = [
|
||||
|
|
|
@ -99,9 +99,8 @@ def get_remote_inference_providers() -> list[Provider]:
|
|||
return inference_providers
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
||||
remote_inference_providers = get_remote_inference_providers()
|
||||
name = "starter"
|
||||
|
||||
providers = {
|
||||
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]
|
||||
|
|
|
@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
|
|||
|
||||
# TODO: set expiration time for garbage collection
|
||||
|
||||
if endpoint not in ["/v1/chat/completions"]:
|
||||
if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
|
||||
raise ValueError(
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
|
||||
)
|
||||
|
||||
if completion_window != "24h":
|
||||
|
@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches):
|
|||
)
|
||||
valid = False
|
||||
|
||||
for param, expected_type, type_string in [
|
||||
("model", str, "a string"),
|
||||
# messages is specific to /v1/chat/completions
|
||||
# we could skip validating messages here and let inference fail. however,
|
||||
# that would be a very expensive way to find out messages is wrong.
|
||||
("messages", list, "an array"), # TODO: allow messages to be a string?
|
||||
]:
|
||||
if batch.endpoint == "/v1/chat/completions":
|
||||
required_params = [
|
||||
("model", str, "a string"),
|
||||
# messages is specific to /v1/chat/completions
|
||||
# we could skip validating messages here and let inference fail. however,
|
||||
# that would be a very expensive way to find out messages is wrong.
|
||||
("messages", list, "an array"), # TODO: allow messages to be a string?
|
||||
]
|
||||
else: # /v1/completions
|
||||
required_params = [
|
||||
("model", str, "a string"),
|
||||
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
|
||||
]
|
||||
|
||||
for param, expected_type, type_string in required_params:
|
||||
if param not in body:
|
||||
errors.append(
|
||||
BatchError(
|
||||
|
@ -591,20 +599,37 @@ class ReferenceBatchesImpl(Batches):
|
|||
|
||||
try:
|
||||
# TODO(SECURITY): review body for security issues
|
||||
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
|
||||
chat_response = await self.inference_api.openai_chat_completion(**request.body)
|
||||
if request.url == "/v1/chat/completions":
|
||||
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
|
||||
chat_response = await self.inference_api.openai_chat_completion(**request.body)
|
||||
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": request_id, # TODO: should this be different?
|
||||
"body": chat_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": request_id, # TODO: should this be different?
|
||||
"body": chat_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
else: # /v1/completions
|
||||
completion_response = await self.inference_api.openai_completion(**request.body)
|
||||
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
assert hasattr(completion_response, "model_dump_json"), (
|
||||
"Completion response must have model_dump_json method"
|
||||
)
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": request_id,
|
||||
"body": completion_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
|
||||
return {
|
||||
|
|
|
@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig
|
|||
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
|
||||
from .memory import MemoryToolRuntimeImpl
|
||||
|
||||
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
|
||||
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -5,10 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import mimetypes
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import UploadFile
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
@ -17,6 +22,7 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.files import Files, OpenAIFilePurpose
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
|
@ -30,13 +36,18 @@ from llama_stack.apis.tools import (
|
|||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_io import (
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
VectorStoreChunkingStrategyStatic,
|
||||
VectorStoreChunkingStrategyStaticConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
content_from_doc,
|
||||
make_overlapped_chunks,
|
||||
parse_data_url,
|
||||
)
|
||||
|
||||
from .config import RagToolRuntimeConfig
|
||||
|
@ -55,10 +66,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
config: RagToolRuntimeConfig,
|
||||
vector_io_api: VectorIO,
|
||||
inference_api: Inference,
|
||||
files_api: Files,
|
||||
):
|
||||
self.config = config
|
||||
self.vector_io_api = vector_io_api
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
@ -78,27 +91,50 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
chunks = []
|
||||
if not documents:
|
||||
return
|
||||
|
||||
for doc in documents:
|
||||
content = await content_from_doc(doc)
|
||||
# TODO: we should add enrichment here as URLs won't be added to the metadata by default
|
||||
chunks.extend(
|
||||
make_overlapped_chunks(
|
||||
doc.document_id,
|
||||
content,
|
||||
chunk_size_in_tokens,
|
||||
chunk_size_in_tokens // 4,
|
||||
doc.metadata,
|
||||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
parts = parse_data_url(doc.content.uri)
|
||||
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
|
||||
mime_type = parts["mimetype"]
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(doc.content.uri)
|
||||
file_data = response.content
|
||||
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
|
||||
else:
|
||||
content_str = await content_from_doc(doc)
|
||||
file_data = content_str.encode("utf-8")
|
||||
mime_type = doc.mime_type or "text/plain"
|
||||
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
|
||||
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
|
||||
|
||||
file_obj = io.BytesIO(file_data)
|
||||
file_obj.name = filename
|
||||
|
||||
upload_file = UploadFile(file=file_obj, filename=filename)
|
||||
|
||||
created_file = await self.files_api.openai_upload_file(
|
||||
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
|
||||
)
|
||||
|
||||
chunking_strategy = VectorStoreChunkingStrategyStatic(
|
||||
static=VectorStoreChunkingStrategyStaticConfig(
|
||||
max_chunk_size_tokens=chunk_size_in_tokens,
|
||||
chunk_overlap_tokens=chunk_size_in_tokens // 4,
|
||||
)
|
||||
)
|
||||
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
await self.vector_io_api.insert_chunks(
|
||||
chunks=chunks,
|
||||
vector_db_id=vector_db_id,
|
||||
)
|
||||
await self.vector_io_api.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_db_id,
|
||||
file_id=created_file.id,
|
||||
attributes=doc.metadata,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
|
||||
async def query(
|
||||
self,
|
||||
|
|
|
@ -116,7 +116,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="fireworks",
|
||||
pip_packages=[
|
||||
"fireworks-ai<=0.18.0",
|
||||
"fireworks-ai<=0.17.16",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.fireworks",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
||||
|
@ -207,7 +207,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="gemini",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=["litellm", "openai"],
|
||||
module="llama_stack.providers.remote.inference.gemini",
|
||||
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||
|
@ -270,7 +270,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova",
|
||||
pip_packages=["litellm"],
|
||||
pip_packages=["litellm", "openai"],
|
||||
module="llama_stack.providers.remote.inference.sambanova",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||
|
|
|
@ -32,7 +32,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
],
|
||||
module="llama_stack.providers.inline.tool_runtime.rag",
|
||||
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
|
||||
api_dependencies=[Api.vector_io, Api.inference],
|
||||
api_dependencies=[Api.vector_io, Api.inference, Api.files],
|
||||
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
|
|
@ -5,12 +5,13 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import GeminiConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: GeminiConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
|
@ -21,6 +22,11 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
)
|
||||
self.config = config
|
||||
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
|
|
|
@ -4,13 +4,26 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import SambaNovaImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
"""
|
||||
SambaNova Inference Adapter for Llama Stack.
|
||||
|
||||
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
|
||||
is used instead of LiteLLMOpenAIMixin.check_model_availability().
|
||||
|
||||
- OpenAIMixin.check_model_availability() queries the /v1/models to check if a model exists
|
||||
- LiteLLMOpenAIMixin.check_model_availability() checks the static registry within LiteLLM
|
||||
"""
|
||||
|
||||
def __init__(self, config: SambaNovaImplConfig):
|
||||
self.config = config
|
||||
self.environment_available_models = []
|
||||
|
@ -24,3 +37,14 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
download_images=True, # SambaNova requires base64 image encoding
|
||||
json_schema_strict=False, # SambaNova doesn't support strict=True yet
|
||||
)
|
||||
|
||||
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
||||
:return: The SambaNova base URL
|
||||
"""
|
||||
return self.config.url
|
||||
|
|
|
@ -4,53 +4,55 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BedrockBaseConfig(BaseModel):
|
||||
aws_access_key_id: str | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||
)
|
||||
aws_secret_access_key: str | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
||||
)
|
||||
aws_session_token: str | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"),
|
||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
||||
)
|
||||
region_name: str | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: os.getenv("AWS_DEFAULT_REGION"),
|
||||
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
||||
"Default use environment variable: AWS_DEFAULT_REGION",
|
||||
)
|
||||
profile_name: str | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: os.getenv("AWS_PROFILE"),
|
||||
description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE",
|
||||
)
|
||||
total_max_attempts: int | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: int(val) if (val := os.getenv("AWS_MAX_ATTEMPTS")) else None,
|
||||
description="An integer representing the maximum number of attempts that will be made for a single request, "
|
||||
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
|
||||
)
|
||||
retry_mode: str | None = Field(
|
||||
default=None,
|
||||
default_factory=lambda: os.getenv("AWS_RETRY_MODE"),
|
||||
description="A string representing the type of retries Boto3 will perform."
|
||||
"Default use environment variable: AWS_RETRY_MODE",
|
||||
)
|
||||
connect_timeout: float | None = Field(
|
||||
default=60,
|
||||
default_factory=lambda: float(os.getenv("AWS_CONNECT_TIMEOUT", "60")),
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
|
||||
"The default is 60 seconds.",
|
||||
)
|
||||
read_timeout: float | None = Field(
|
||||
default=60,
|
||||
default_factory=lambda: float(os.getenv("AWS_READ_TIMEOUT", "60")),
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
||||
"The default is 60 seconds.",
|
||||
)
|
||||
session_ttl: int | None = Field(
|
||||
default=3600,
|
||||
default_factory=lambda: int(os.getenv("AWS_SESSION_TTL", "3600")),
|
||||
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
|
||||
)
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import struct
|
||||
from typing import TYPE_CHECKING
|
||||
|
@ -43,9 +44,11 @@ class SentenceTransformerEmbeddingMixin:
|
|||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
embeddings = embedding_model.encode(
|
||||
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False
|
||||
embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
embeddings = await asyncio.to_thread(
|
||||
embedding_model.encode,
|
||||
[interleaved_content_as_str(content) for content in contents],
|
||||
show_progress_bar=False,
|
||||
)
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
|
@ -64,8 +67,8 @@ class SentenceTransformerEmbeddingMixin:
|
|||
|
||||
# Get the model and generate embeddings
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
|
||||
embeddings = embedding_model.encode(input_list, show_progress_bar=False)
|
||||
embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
|
||||
embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
|
||||
|
||||
# Convert embeddings to the requested format
|
||||
data = []
|
||||
|
@ -93,7 +96,7 @@ class SentenceTransformerEmbeddingMixin:
|
|||
usage=usage,
|
||||
)
|
||||
|
||||
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
|
||||
async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
|
||||
global EMBEDDING_MODELS
|
||||
|
||||
loaded_model = EMBEDDING_MODELS.get(model)
|
||||
|
@ -101,8 +104,12 @@ class SentenceTransformerEmbeddingMixin:
|
|||
return loaded_model
|
||||
|
||||
log.info(f"Loading sentence transformer for {model}...")
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
loaded_model = SentenceTransformer(model)
|
||||
def _load_model():
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
return SentenceTransformer(model)
|
||||
|
||||
loaded_model = await asyncio.to_thread(_load_model)
|
||||
EMBEDDING_MODELS[model] = loaded_model
|
||||
return loaded_model
|
||||
|
|
|
@ -67,6 +67,38 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
|
|||
raise AuthenticationRequiredError(exc) from exc
|
||||
if i == len(connection_strategies) - 1:
|
||||
raise
|
||||
except* httpx.ConnectError as eg:
|
||||
# Connection refused, server down, network unreachable
|
||||
if i == len(connection_strategies) - 1:
|
||||
error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused"
|
||||
logger.error(f"MCP connection error: {error_msg}")
|
||||
raise ConnectionError(error_msg) from eg
|
||||
else:
|
||||
logger.warning(
|
||||
f"failed to connect to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
||||
)
|
||||
except* httpx.TimeoutException as eg:
|
||||
# Request timeout, server too slow
|
||||
if i == len(connection_strategies) - 1:
|
||||
error_msg = f"MCP server at {endpoint} timed out"
|
||||
logger.error(f"MCP timeout error: {error_msg}")
|
||||
raise TimeoutError(error_msg) from eg
|
||||
else:
|
||||
logger.warning(
|
||||
f"MCP server at {endpoint} timed out via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
||||
)
|
||||
except* httpx.RequestError as eg:
|
||||
# DNS resolution failures, network errors, invalid URLs
|
||||
if i == len(connection_strategies) - 1:
|
||||
# Get the first exception's message for the error string
|
||||
exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error"
|
||||
error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}"
|
||||
logger.error(f"MCP network error: {error_msg}")
|
||||
raise ConnectionError(error_msg) from eg
|
||||
else:
|
||||
logger.warning(
|
||||
f"network error connecting to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
||||
)
|
||||
except* McpError:
|
||||
if i < len(connection_strategies) - 1:
|
||||
logger.warning(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue