mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-11 20:40:40 +00:00
commit
This commit is contained in:
parent
ca2922a455
commit
cb42e1d9d4
2 changed files with 57 additions and 30 deletions
|
@ -545,7 +545,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
elif delta.type == "text":
|
elif delta.type == "text":
|
||||||
delta.text = "hello"
|
# delta.text = "hello"
|
||||||
content += delta.text
|
content += delta.text
|
||||||
if stream and event.stop_reason is None:
|
if stream and event.stop_reason is None:
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
|
|
|
@ -4,12 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import AsyncGenerator, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
from llama_stack_client import AsyncLlamaStackClient
|
from llama_stack_client import AsyncLlamaStackClient
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent
|
from llama_stack.apis.common.content_types import InterleavedContent
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
|
ChatCompletionResponse,
|
||||||
|
ChatCompletionResponseStreamChunk,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
EmbeddingTaskType,
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
|
@ -24,6 +26,7 @@ from llama_stack.apis.inference import (
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
|
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
|
|
||||||
from .config import PassthroughImplConfig
|
from .config import PassthroughImplConfig
|
||||||
|
@ -120,10 +123,14 @@ class PassthroughInferenceAdapter(Inference):
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
if sampling_params is None:
|
if sampling_params is None:
|
||||||
sampling_params = SamplingParams()
|
sampling_params = SamplingParams()
|
||||||
client = self._get_client()
|
|
||||||
model = await self.model_store.get_model(model_id)
|
model = await self.model_store.get_model(model_id)
|
||||||
|
|
||||||
reqeust_params = {
|
# TODO: revisit this remove tool_calls from messages logic
|
||||||
|
for message in messages:
|
||||||
|
if hasattr(message, "tool_calls"):
|
||||||
|
message.tool_calls = None
|
||||||
|
|
||||||
|
request_params = {
|
||||||
"model_id": model.provider_resource_id,
|
"model_id": model.provider_resource_id,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"sampling_params": sampling_params,
|
"sampling_params": sampling_params,
|
||||||
|
@ -134,34 +141,35 @@ class PassthroughInferenceAdapter(Inference):
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
"logprobs": logprobs,
|
"logprobs": logprobs,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
request_params = {key: value for key, value in reqeust_params.items() if value is not None}
|
|
||||||
|
|
||||||
json_params = {}
|
|
||||||
from llama_stack.distribution.library_client import (
|
|
||||||
convert_pydantic_to_json_value,
|
|
||||||
)
|
|
||||||
|
|
||||||
# cast everything to json dict
|
|
||||||
for key, value in request_params.items():
|
|
||||||
json_input = convert_pydantic_to_json_value(value)
|
|
||||||
if isinstance(json_input, dict):
|
|
||||||
json_input = {k: v for k, v in json_input.items() if v is not None}
|
|
||||||
elif isinstance(json_input, list):
|
|
||||||
json_input = [x for x in json_input if x is not None]
|
|
||||||
new_input = []
|
|
||||||
for x in json_input:
|
|
||||||
if isinstance(x, dict):
|
|
||||||
x = {k: v for k, v in x.items() if v is not None}
|
|
||||||
new_input.append(x)
|
|
||||||
json_input = new_input
|
|
||||||
|
|
||||||
# if key != "tools":
|
|
||||||
json_params[key] = json_input
|
|
||||||
|
|
||||||
# only pass through the not None params
|
# only pass through the not None params
|
||||||
return await client.inference.chat_completion(**json_params)
|
request_params = {key: value for key, value in request_params.items() if value is not None}
|
||||||
|
|
||||||
|
# cast everything to json dict
|
||||||
|
json_params = self.cast_value_to_json_dict(request_params)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._stream_chat_completion(json_params)
|
||||||
|
else:
|
||||||
|
return await self._nonstream_chat_completion(json_params)
|
||||||
|
|
||||||
|
async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse:
|
||||||
|
client = self._get_client()
|
||||||
|
response = await client.inference.chat_completion(**json_params)
|
||||||
|
response = response.to_dict()
|
||||||
|
response["metrics"] = []
|
||||||
|
|
||||||
|
return convert_to_pydantic(ChatCompletionResponse, response)
|
||||||
|
|
||||||
|
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
|
||||||
|
client = self._get_client()
|
||||||
|
stream_response = await client.inference.chat_completion(**json_params)
|
||||||
|
|
||||||
|
async for chunk in stream_response:
|
||||||
|
chunk = chunk.to_dict()
|
||||||
|
chunk["metrics"] = []
|
||||||
|
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
|
||||||
|
yield chunk
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
|
@ -181,3 +189,22 @@ class PassthroughInferenceAdapter(Inference):
|
||||||
output_dimension=output_dimension,
|
output_dimension=output_dimension,
|
||||||
task_type=task_type,
|
task_type=task_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
json_params = {}
|
||||||
|
for key, value in request_params.items():
|
||||||
|
json_input = convert_pydantic_to_json_value(value)
|
||||||
|
if isinstance(json_input, dict):
|
||||||
|
json_input = {k: v for k, v in json_input.items() if v is not None}
|
||||||
|
elif isinstance(json_input, list):
|
||||||
|
json_input = [x for x in json_input if x is not None]
|
||||||
|
new_input = []
|
||||||
|
for x in json_input:
|
||||||
|
if isinstance(x, dict):
|
||||||
|
x = {k: v for k, v in x.items() if v is not None}
|
||||||
|
new_input.append(x)
|
||||||
|
json_input = new_input
|
||||||
|
|
||||||
|
json_params[key] = json_input
|
||||||
|
|
||||||
|
return json_params
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue