Switch passthrough inference adapter to AsyncLlamaStackClient and JSON-serialize request params

This commit is contained in:
Botao Chen 2025-03-11 15:40:35 -07:00
parent e3edca7739
commit ca2922a455
3 changed files with 44 additions and 9 deletions

View file

@@ -545,6 +545,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
elif delta.type == "text": elif delta.type == "text":
delta.text = "hello"
content += delta.text content += delta.text
if stream and event.stop_reason is None: if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(

View file

@@ -6,7 +6,7 @@
from typing import AsyncGenerator, List, Optional from typing import AsyncGenerator, List, Optional
from llama_stack_client import LlamaStackClient from llama_stack_client import AsyncLlamaStackClient
from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
@@ -46,7 +46,7 @@ class PassthroughInferenceAdapter(Inference):
async def register_model(self, model: Model) -> Model: async def register_model(self, model: Model) -> Model:
return model return model
def _get_client(self) -> LlamaStackClient: def _get_client(self) -> AsyncLlamaStackClient:
passthrough_url = None passthrough_url = None
passthrough_api_key = None passthrough_api_key = None
provider_data = None provider_data = None
@@ -71,7 +71,7 @@ class PassthroughInferenceAdapter(Inference):
) )
passthrough_api_key = provider_data.passthrough_api_key passthrough_api_key = provider_data.passthrough_api_key
return LlamaStackClient( return AsyncLlamaStackClient(
base_url=passthrough_url, base_url=passthrough_url,
api_key=passthrough_api_key, api_key=passthrough_api_key,
provider_data=provider_data, provider_data=provider_data,
@@ -103,7 +103,7 @@ class PassthroughInferenceAdapter(Inference):
params = {key: value for key, value in params.items() if value is not None} params = {key: value for key, value in params.items() if value is not None}
# only pass through the not None params # only pass through the not None params
return client.inference.completion(**params) return await client.inference.completion(**params)
async def chat_completion( async def chat_completion(
self, self,
@@ -123,7 +123,7 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client() client = self._get_client()
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
params = { request_params = {
"model_id": model.provider_resource_id, "model_id": model.provider_resource_id,
"messages": messages, "messages": messages,
"sampling_params": sampling_params, "sampling_params": sampling_params,
@ -135,10 +135,33 @@ class PassthroughInferenceAdapter(Inference):
"logprobs": logprobs, "logprobs": logprobs,
} }
params = {key: value for key, value in params.items() if value is not None}
request_params = {key: value for key, value in request_params.items() if value is not None}
json_params = {}
from llama_stack.distribution.library_client import (
convert_pydantic_to_json_value,
)
# cast everything to json dict
for key, value in request_params.items():
json_input = convert_pydantic_to_json_value(value)
if isinstance(json_input, dict):
json_input = {k: v for k, v in json_input.items() if v is not None}
elif isinstance(json_input, list):
json_input = [x for x in json_input if x is not None]
new_input = []
for x in json_input:
if isinstance(x, dict):
x = {k: v for k, v in x.items() if v is not None}
new_input.append(x)
json_input = new_input
# if key != "tools":
json_params[key] = json_input
# only pass through the not None params # only pass through the not None params
return client.inference.chat_completion(**params) return await client.inference.chat_completion(**json_params)
async def embeddings( async def embeddings(
self, self,
@@ -151,7 +174,7 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client() client = self._get_client()
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
return client.inference.embeddings( return await client.inference.embeddings(
model_id=model.provider_resource_id, model_id=model.provider_resource_id,
contents=contents, contents=contents,
text_truncation=text_truncation, text_truncation=text_truncation,

View file

@@ -20,6 +20,13 @@ providers:
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {} config: {}
- provider_id: meta-reference-inference
provider_type: inline::meta-reference
config:
model: meta-llama/Llama-Guard-3-1B
max_seq_len: 4096
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
# api_key: ${env.TOGETHER_API_KEY}
vector_io: vector_io:
- provider_id: faiss - provider_id: faiss
provider_type: inline::faiss provider_type: inline::faiss
@@ -103,8 +110,12 @@ models:
provider_id: passthrough provider_id: passthrough
provider_model_id: llama3.2-11b-vision-instruct provider_model_id: llama3.2-11b-vision-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-1B
provider_id: meta-reference-inference
model_type: llm
shields: shields:
- shield_id: meta-llama/Llama-Guard-3-8B - shield_id: meta-llama/Llama-Guard-3-1B
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []