Merge branch 'main' into agent_session_unit_test

Francisco Arceo, 2025-08-13 08:39:53 -06:00 (committed by GitHub)
commit ef5b918996
5 changed files with 15 additions and 9 deletions

View file

@@ -11,7 +11,7 @@ on:
      - synchronize

 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number }}
   cancel-in-progress: true

 permissions:
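With cancel-in-progress: true, any in-flight run that shares the concurrency group is cancelled when a new run starts. A rough Python sketch of how the old and new group keys differ for a hypothetical PR #123 (the workflow name and context values are illustrative only):

# Illustrative comparison of the two concurrency keys; "example-workflow"
# and the context values below are made up for this sketch.
workflow = "example-workflow"
ctx = {"ref": "refs/pull/123/merge", "pull_request_number": 123}

old_key = f"{workflow}-{ctx['ref']}"                  # example-workflow-refs/pull/123/merge
new_key = f"{workflow}-{ctx['pull_request_number']}"  # example-workflow-123
print(old_key)
print(new_key)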

View file

@@ -16,6 +16,7 @@ from llama_stack.distributions.template import DistributionTemplate, RunConfigSe
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
+from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig


 def get_distribution_template() -> DistributionTemplate:
@@ -71,9 +72,10 @@ def get_distribution_template() -> DistributionTemplate:
     chromadb_provider = Provider(
         provider_id="chromadb",
         provider_type="remote::chromadb",
-        config={
-            "url": "${env.CHROMA_URL}",
-        },
+        config=ChromaVectorIOConfig.sample_run_config(
+            f"~/.llama/distributions/{name}/",
+            url="${env.CHROMADB_URL:=}",
+        ),
     )
     inference_model = ModelInput(
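For reference, a sketch of the config this call is expected to produce for the Dell template; the field names are inferred from the generated run.yaml hunks below, not read from ChromaVectorIOConfig itself:

# Hypothetical stand-in for ChromaVectorIOConfig.sample_run_config(); the
# keys (url, kvstore, type, db_path) mirror the run.yaml shown below and
# may not match the real class exactly.
def sample_run_config_sketch(db_dir: str, url: str) -> dict:
    return {
        "url": url,  # stays "${env.CHROMADB_URL:=}" until env substitution at run time
        "kvstore": {
            "type": "sqlite",
            "db_path": f"{db_dir}/chroma_remote_registry.db",
        },
    }

print(sample_run_config_sketch("~/.llama/distributions/dell", "${env.CHROMADB_URL:=}"))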

View file

@@ -26,7 +26,10 @@ providers:
   - provider_id: chromadb
     provider_type: remote::chromadb
     config:
-      url: ${env.CHROMA_URL}
+      url: ${env.CHROMADB_URL:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
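The Dell run config renames the variable to CHROMADB_URL with an empty default and adds a sqlite kvstore for the remote registry. A minimal sketch of the `${env.VAR:=default}` expansion these values rely on (illustrative only, not llama_stack's actual resolver):

import os
import re

# Minimal sketch of "${env.VAR:=default}" substitution as used in the
# run.yaml above; llama_stack's real resolver may differ in details.
_ENV_REF = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*)(?::=([^}]*))?\}")

def expand_env(value: str) -> str:
    return _ENV_REF.sub(lambda m: os.environ.get(m.group(1), m.group(2) or ""), value)

# Empty string when CHROMADB_URL is unset; the exported value otherwise.
print(expand_env("${env.CHROMADB_URL:=}"))
print(expand_env("${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db"))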

View file

@@ -22,7 +22,10 @@ providers:
   - provider_id: chromadb
     provider_type: remote::chromadb
     config:
-      url: ${env.CHROMA_URL}
+      url: ${env.CHROMADB_URL:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
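The second Dell run configuration gets the identical change. The kvstore entry points the Chroma provider's registry at a sqlite file; a quick, hedged way to confirm the file appears after startup (the path assumes SQLITE_STORE_DIR is left unset):

import sqlite3
from pathlib import Path

# Assumed default location when SQLITE_STORE_DIR is not exported; adjust
# the path if you point the store directory elsewhere.
db = Path("~/.llama/distributions/dell/chroma_remote_registry.db").expanduser()
if db.exists():
    with sqlite3.connect(db) as conn:
        tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
    print(f"{db}: tables={tables}")
else:
    print(f"{db} not found yet; start the Dell distribution first.")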

View file

@@ -308,9 +308,7 @@ class TGIAdapter(_HfAdapter):
         if not config.url:
             raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
         log.info(f"Initializing TGI client with url={config.url}")
-        self.client = AsyncInferenceClient(
-            model=config.url,
-        )
+        self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]
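A standalone sketch of the same initialization flow, useful for exercising the client outside the adapter; the localhost URL is an assumed placeholder for a running text-generation-inference server:

import asyncio

from huggingface_hub import AsyncInferenceClient

# Mirrors the adapter's initialization above; http://localhost:8080 is an
# assumed placeholder for a local TGI server.
async def probe_tgi(url: str) -> None:
    client = AsyncInferenceClient(model=url, provider="hf-inference")
    info = await client.get_endpoint_info()  # served from TGI's /info endpoint
    print(info["model_id"], info["max_total_tokens"])

asyncio.run(probe_tgi("http://localhost:8080"))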