temp commit

Botao Chen 2025-03-11 15:40:35 -07:00
parent e3edca7739
commit ca2922a455
3 changed files with 44 additions and 9 deletions


@@ -545,6 +545,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 elif delta.type == "text":
+                    delta.text = "hello"
                     content += delta.text
 
                     if stream and event.stop_reason is None:
                         yield AgentTurnResponseStreamChunk(

@@ -6,7 +6,7 @@
 
 from typing import AsyncGenerator, List, Optional
 
-from llama_stack_client import LlamaStackClient
+from llama_stack_client import AsyncLlamaStackClient
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -46,7 +46,7 @@ class PassthroughInferenceAdapter(Inference):
     async def register_model(self, model: Model) -> Model:
         return model
 
-    def _get_client(self) -> LlamaStackClient:
+    def _get_client(self) -> AsyncLlamaStackClient:
         passthrough_url = None
         passthrough_api_key = None
         provider_data = None
@@ -71,7 +71,7 @@ class PassthroughInferenceAdapter(Inference):
             )
             passthrough_api_key = provider_data.passthrough_api_key
 
-        return LlamaStackClient(
+        return AsyncLlamaStackClient(
             base_url=passthrough_url,
             api_key=passthrough_api_key,
             provider_data=provider_data,
@@ -103,7 +103,7 @@ class PassthroughInferenceAdapter(Inference):
         params = {key: value for key, value in params.items() if value is not None}
 
         # only pass through the non-None params
-        return client.inference.completion(**params)
+        return await client.inference.completion(**params)
 
     async def chat_completion(
         self,
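The return-type change to AsyncLlamaStackClient is the crux of this commit: every client method now returns a coroutine, so each call site gains an await (here, and in chat_completion and embeddings below). A minimal sketch of the calling pattern, with a placeholder base URL and model id:

import asyncio
from llama_stack_client import AsyncLlamaStackClient

async def demo() -> None:
    # placeholder endpoint; the adapter builds this from provider data
    client = AsyncLlamaStackClient(base_url="http://localhost:8321")
    response = await client.inference.completion(
        model_id="llama3.2-11b-vision-instruct",
        content="Say hi.",
    )
    print(response)

asyncio.run(demo())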
@@ -123,7 +123,7 @@ class PassthroughInferenceAdapter(Inference):
         client = self._get_client()
         model = await self.model_store.get_model(model_id)
 
-        params = {
+        request_params = {
             "model_id": model.provider_resource_id,
             "messages": messages,
             "sampling_params": sampling_params,
@@ -134,11 +134,34 @@ class PassthroughInferenceAdapter(Inference):
             "stream": stream,
             "logprobs": logprobs,
         }
-        params = {key: value for key, value in params.items() if value is not None}
+        request_params = {key: value for key, value in request_params.items() if value is not None}
+
+        json_params = {}
+
+        from llama_stack.distribution.library_client import (
+            convert_pydantic_to_json_value,
+        )
+
+        # cast every value to a plain JSON value, dropping None entries from
+        # top-level dicts and from dicts nested one level inside lists
+        for key, value in request_params.items():
+            json_input = convert_pydantic_to_json_value(value)
+            if isinstance(json_input, dict):
+                json_input = {k: v for k, v in json_input.items() if v is not None}
+            elif isinstance(json_input, list):
+                json_input = [x for x in json_input if x is not None]
+                new_input = []
+                for x in json_input:
+                    if isinstance(x, dict):
+                        x = {k: v for k, v in x.items() if v is not None}
+                    new_input.append(x)
+                json_input = new_input
+
+            json_params[key] = json_input
 
         # only pass through the non-None params
-        return client.inference.chat_completion(**params)
+        return await client.inference.chat_completion(**json_params)
 
     async def embeddings(
         self,
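The new loop does two things: convert_pydantic_to_json_value turns each Pydantic request value into a plain JSON value, and the isinstance branches then strip None entries from the top level of dicts and from dicts nested one level inside lists. A standalone sketch of that stripping step, with plain dicts standing in for converted request values (a simplified stand-in, not the adapter's actual helper):

def strip_none_one_level(json_input):
    # drop None values from a dict, or from dicts one level inside a list
    if isinstance(json_input, dict):
        return {k: v for k, v in json_input.items() if v is not None}
    if isinstance(json_input, list):
        cleaned = [x for x in json_input if x is not None]
        return [
            {k: v for k, v in x.items() if v is not None} if isinstance(x, dict) else x
            for x in cleaned
        ]
    return json_input

print(strip_none_one_level({"strategy": "greedy", "top_p": None}))
# {'strategy': 'greedy'}
print(strip_none_one_level([{"role": "user", "context": None}, None]))
# [{'role': 'user'}]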
@@ -151,7 +174,7 @@ class PassthroughInferenceAdapter(Inference):
         client = self._get_client()
         model = await self.model_store.get_model(model_id)
 
-        return client.inference.embeddings(
+        return await client.inference.embeddings(
             model_id=model.provider_resource_id,
             contents=contents,
             text_truncation=text_truncation,


@@ -20,6 +20,13 @@ providers:
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
+  - provider_id: meta-reference-inference
+    provider_type: inline::meta-reference
+    config:
+      model: meta-llama/Llama-Guard-3-1B
+      max_seq_len: 4096
+      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+      # api_key: ${env.TOGETHER_API_KEY}
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
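checkpoint_dir uses the stack's ${env.VAR:default} substitution, so the checkpoint location can be supplied at launch time and falls back to null when the variable is unset. A simplified sketch of that resolution rule (a stand-in for illustration, not the stack's actual resolver):

import os
import re

def resolve_env(value: str):
    # match ${env.NAME:default} and fall back to the default when NAME is unset
    match = re.fullmatch(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):(.*)\}", value)
    if not match:
        return value
    name, default = match.groups()
    resolved = os.environ.get(name, default)
    return None if resolved == "null" else resolved

print(resolve_env("${env.INFERENCE_CHECKPOINT_DIR:null}"))  # None unless exported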
@@ -103,8 +110,12 @@ models:
   provider_id: passthrough
   provider_model_id: llama3.2-11b-vision-instruct
   model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-Guard-3-1B
+  provider_id: meta-reference-inference
+  model_type: llm
 shields:
-- shield_id: meta-llama/Llama-Guard-3-8B
+- shield_id: meta-llama/Llama-Guard-3-1B
 vector_dbs: []
 datasets: []
 scoring_fns: []
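With the 1B Guard model served by the new meta-reference provider and the shield id re-pointed at it, the shield can be smoke-tested through the client. A hypothetical check, assuming a stack serving this config locally (the URL and message are placeholders):

import asyncio
from llama_stack_client import AsyncLlamaStackClient

async def main() -> None:
    client = AsyncLlamaStackClient(base_url="http://localhost:8321")
    result = await client.safety.run_shield(
        shield_id="meta-llama/Llama-Guard-3-1B",
        messages=[{"role": "user", "content": "How do I bake a cake?"}],
        params={},
    )
    print(result)  # carries a violation field when the shield flags the input

asyncio.run(main())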