diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index aa8a87bf7..8f3a0d147 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -4,12 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack_client import LlamaStackClient +from llama_stack_client import AsyncLlamaStackClient from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseStreamChunk, EmbeddingsResponse, EmbeddingTaskType, Inference, @@ -24,6 +26,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model +from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from .config import PassthroughImplConfig @@ -46,7 +49,7 @@ class PassthroughInferenceAdapter(Inference): async def register_model(self, model: Model) -> Model: return model - def _get_client(self) -> LlamaStackClient: + def _get_client(self) -> AsyncLlamaStackClient: passthrough_url = None passthrough_api_key = None provider_data = None @@ -71,7 +74,7 @@ class PassthroughInferenceAdapter(Inference): ) passthrough_api_key = provider_data.passthrough_api_key - return LlamaStackClient( + return AsyncLlamaStackClient( base_url=passthrough_url, api_key=passthrough_api_key, provider_data=provider_data, @@ -91,7 +94,7 @@ class PassthroughInferenceAdapter(Inference): client = self._get_client() model = await self.model_store.get_model(model_id) - params = { + request_params = { "model_id": model.provider_resource_id, "content": content, "sampling_params": sampling_params, @@ -100,10 +103,13 @@ class PassthroughInferenceAdapter(Inference): "logprobs": logprobs, } - params = {key: value for key, value in params.items() if value is not None} + request_params = {key: value for key, value in request_params.items() if value is not None} + + # cast everything to json dict + json_params = self.cast_value_to_json_dict(request_params) # only pass through the not None params - return client.inference.completion(**params) + return await client.inference.completion(**json_params) async def chat_completion( self, @@ -120,10 +126,14 @@ class PassthroughInferenceAdapter(Inference): ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() - client = self._get_client() model = await self.model_store.get_model(model_id) - params = { + # TODO: revisit this remove tool_calls from messages logic + for message in messages: + if hasattr(message, "tool_calls"): + message.tool_calls = None + + request_params = { "model_id": model.provider_resource_id, "messages": messages, "sampling_params": sampling_params, @@ -135,10 +145,39 @@ class PassthroughInferenceAdapter(Inference): "logprobs": logprobs, } - params = {key: value for key, value in params.items() if value is not None} - # only pass through the not None params - return client.inference.chat_completion(**params) + request_params = {key: value for key, value in request_params.items() if value is not None} + + # cast everything to json dict + json_params = self.cast_value_to_json_dict(request_params) + + if stream: + return self._stream_chat_completion(json_params) + else: + return await self._nonstream_chat_completion(json_params) + + async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse: + client = self._get_client() + response = await client.inference.chat_completion(**json_params) + + response = response.to_dict() + + # temporary hack to remove the metrics from the response + response["metrics"] = [] + + return convert_to_pydantic(ChatCompletionResponse, response) + + async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator: + client = self._get_client() + stream_response = await client.inference.chat_completion(**json_params) + + async for chunk in stream_response: + chunk = chunk.to_dict() + + # temporary hack to remove the metrics from the response + chunk["metrics"] = [] + chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk) + yield chunk async def embeddings( self, @@ -151,10 +190,29 @@ class PassthroughInferenceAdapter(Inference): client = self._get_client() model = await self.model_store.get_model(model_id) - return client.inference.embeddings( + return await client.inference.embeddings( model_id=model.provider_resource_id, contents=contents, text_truncation=text_truncation, output_dimension=output_dimension, task_type=task_type, ) + + def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]: + json_params = {} + for key, value in request_params.items(): + json_input = convert_pydantic_to_json_value(value) + if isinstance(json_input, dict): + json_input = {k: v for k, v in json_input.items() if v is not None} + elif isinstance(json_input, list): + json_input = [x for x in json_input if x is not None] + new_input = [] + for x in json_input: + if isinstance(x, dict): + x = {k: v for k, v in x.items() if v is not None} + new_input.append(x) + json_input = new_input + + json_params[key] = json_input + + return json_params