pre-commit fixes

This commit is contained in:
Chantal D Gama Rose 2025-03-14 13:56:05 -07:00
parent 967dd0aa08
commit 7e211f8553
314 changed files with 5574 additions and 11369 deletions

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack_client import LlamaStackClient
from llama_stack_client import AsyncLlamaStackClient
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@ -24,6 +26,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .config import PassthroughImplConfig
@ -46,7 +49,7 @@ class PassthroughInferenceAdapter(Inference):
async def register_model(self, model: Model) -> Model:
return model
def _get_client(self) -> LlamaStackClient:
def _get_client(self) -> AsyncLlamaStackClient:
passthrough_url = None
passthrough_api_key = None
provider_data = None
@ -71,7 +74,7 @@ class PassthroughInferenceAdapter(Inference):
)
passthrough_api_key = provider_data.passthrough_api_key
return LlamaStackClient(
return AsyncLlamaStackClient(
base_url=passthrough_url,
api_key=passthrough_api_key,
provider_data=provider_data,
@ -81,15 +84,17 @@ class PassthroughInferenceAdapter(Inference):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
client = self._get_client()
model = await self.model_store.get_model(model_id)
params = {
request_params = {
"model_id": model.provider_resource_id,
"content": content,
"sampling_params": sampling_params,
@ -98,16 +103,19 @@ class PassthroughInferenceAdapter(Inference):
"logprobs": logprobs,
}
params = {key: value for key, value in params.items() if value is not None}
request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
# only pass through the not None params
return client.inference.completion(**params)
return await client.inference.completion(**json_params)
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
@ -116,10 +124,16 @@ class PassthroughInferenceAdapter(Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
client = self._get_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
params = {
# TODO: revisit this remove tool_calls from messages logic
for message in messages:
if hasattr(message, "tool_calls"):
message.tool_calls = None
request_params = {
"model_id": model.provider_resource_id,
"messages": messages,
"sampling_params": sampling_params,
@ -131,10 +145,39 @@ class PassthroughInferenceAdapter(Inference):
"logprobs": logprobs,
}
params = {key: value for key, value in params.items() if value is not None}
# only pass through the not None params
return client.inference.chat_completion(**params)
request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
if stream:
return self._stream_chat_completion(json_params)
else:
return await self._nonstream_chat_completion(json_params)
async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse:
client = self._get_client()
response = await client.inference.chat_completion(**json_params)
response = response.to_dict()
# temporary hack to remove the metrics from the response
response["metrics"] = []
return convert_to_pydantic(ChatCompletionResponse, response)
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
client = self._get_client()
stream_response = await client.inference.chat_completion(**json_params)
async for chunk in stream_response:
chunk = chunk.to_dict()
# temporary hack to remove the metrics from the response
chunk["metrics"] = []
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
yield chunk
async def embeddings(
self,
@ -147,10 +190,29 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client()
model = await self.model_store.get_model(model_id)
return client.inference.embeddings(
return await client.inference.embeddings(
model_id=model.provider_resource_id,
contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
)
def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
json_params = {}
for key, value in request_params.items():
json_input = convert_pydantic_to_json_value(value)
if isinstance(json_input, dict):
json_input = {k: v for k, v in json_input.items() if v is not None}
elif isinstance(json_input, list):
json_input = [x for x in json_input if x is not None]
new_input = []
for x in json_input:
if isinstance(x, dict):
x = {k: v for k, v in x.items() if v is not None}
new_input.append(x)
json_input = new_input
json_params[key] = json_input
return json_params