This commit is contained in:
Botao Chen 2025-03-11 20:37:42 -07:00
parent ca2922a455
commit cb42e1d9d4
2 changed files with 57 additions and 30 deletions

View file

@ -545,7 +545,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
elif delta.type == "text": elif delta.type == "text":
delta.text = "hello" # delta.text = "hello"
content += delta.text content += delta.text
if stream and event.stop_reason is None: if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import AsyncGenerator, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack_client import AsyncLlamaStackClient from llama_stack_client import AsyncLlamaStackClient
from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse, EmbeddingsResponse,
EmbeddingTaskType, EmbeddingTaskType,
Inference, Inference,
@ -24,6 +26,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .config import PassthroughImplConfig from .config import PassthroughImplConfig
@ -120,10 +123,14 @@ class PassthroughInferenceAdapter(Inference):
) -> AsyncGenerator: ) -> AsyncGenerator:
if sampling_params is None: if sampling_params is None:
sampling_params = SamplingParams() sampling_params = SamplingParams()
client = self._get_client()
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
reqeust_params = { # TODO: revisit this remove tool_calls from messages logic
for message in messages:
if hasattr(message, "tool_calls"):
message.tool_calls = None
request_params = {
"model_id": model.provider_resource_id, "model_id": model.provider_resource_id,
"messages": messages, "messages": messages,
"sampling_params": sampling_params, "sampling_params": sampling_params,
@ -134,34 +141,35 @@ class PassthroughInferenceAdapter(Inference):
"stream": stream, "stream": stream,
"logprobs": logprobs, "logprobs": logprobs,
} }
request_params = {key: value for key, value in reqeust_params.items() if value is not None}
json_params = {}
from llama_stack.distribution.library_client import (
convert_pydantic_to_json_value,
)
# cast everything to json dict
for key, value in request_params.items():
json_input = convert_pydantic_to_json_value(value)
if isinstance(json_input, dict):
json_input = {k: v for k, v in json_input.items() if v is not None}
elif isinstance(json_input, list):
json_input = [x for x in json_input if x is not None]
new_input = []
for x in json_input:
if isinstance(x, dict):
x = {k: v for k, v in x.items() if v is not None}
new_input.append(x)
json_input = new_input
# if key != "tools":
json_params[key] = json_input
# only pass through the not None params # only pass through the not None params
return await client.inference.chat_completion(**json_params) request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
if stream:
return self._stream_chat_completion(json_params)
else:
return await self._nonstream_chat_completion(json_params)
async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse:
client = self._get_client()
response = await client.inference.chat_completion(**json_params)
response = response.to_dict()
response["metrics"] = []
return convert_to_pydantic(ChatCompletionResponse, response)
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
client = self._get_client()
stream_response = await client.inference.chat_completion(**json_params)
async for chunk in stream_response:
chunk = chunk.to_dict()
chunk["metrics"] = []
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
yield chunk
async def embeddings( async def embeddings(
self, self,
@ -181,3 +189,22 @@ class PassthroughInferenceAdapter(Inference):
output_dimension=output_dimension, output_dimension=output_dimension,
task_type=task_type, task_type=task_type,
) )
def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
json_params = {}
for key, value in request_params.items():
json_input = convert_pydantic_to_json_value(value)
if isinstance(json_input, dict):
json_input = {k: v for k, v in json_input.items() if v is not None}
elif isinstance(json_input, list):
json_input = [x for x in json_input if x is not None]
new_input = []
for x in json_input:
if isinstance(x, dict):
x = {k: v for k, v in x.items() if v is not None}
new_input.append(x)
json_input = new_input
json_params[key] = json_input
return json_params