Merge branch 'main' into agent_session_unit_test

This commit is contained in:
Francisco Arceo 2025-08-13 08:39:53 -06:00 committed by GitHub
commit ef5b918996
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 15 additions and 9 deletions

View file

@@ -11,7 +11,7 @@ on:
   - synchronize
 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number }}
   cancel-in-progress: true
 permissions:

View file

@@ -16,6 +16,7 @@ from llama_stack.distributions.template import DistributionTemplate, RunConfigSe
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
+from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig


 def get_distribution_template() -> DistributionTemplate:
@@ -71,9 +72,10 @@ def get_distribution_template() -> DistributionTemplate:
     chromadb_provider = Provider(
         provider_id="chromadb",
         provider_type="remote::chromadb",
-        config={
-            "url": "${env.CHROMA_URL}",
-        },
+        config=ChromaVectorIOConfig.sample_run_config(
+            f"~/.llama/distributions/{name}/",
+            url="${env.CHROMADB_URL:=}",
+        ),
     )
     inference_model = ModelInput(

View file

@@ -26,7 +26,10 @@ providers:
   - provider_id: chromadb
     provider_type: remote::chromadb
     config:
-      url: ${env.CHROMA_URL}
+      url: ${env.CHROMADB_URL:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard

View file

@@ -22,7 +22,10 @@ providers:
   - provider_id: chromadb
     provider_type: remote::chromadb
     config:
-      url: ${env.CHROMA_URL}
+      url: ${env.CHROMADB_URL:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard

View file

@@ -308,9 +308,7 @@ class TGIAdapter(_HfAdapter):
         if not config.url:
             raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
         log.info(f"Initializing TGI client with url={config.url}")
-        self.client = AsyncInferenceClient(
-            model=config.url,
-        )
+        self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]