From efd842d605fba6e7687315a0f2966895b9f67615 Mon Sep 17 00:00:00 2001 From: Edward Ma Date: Mon, 2 Dec 2024 08:17:22 -0800 Subject: [PATCH] lint --- .../remote/inference/sambanova/sambanova.py | 50 ++++++++++--------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 2a91a8251..8d38b4d4c 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -11,7 +11,7 @@ from llama_models.datatypes import CoreModelId, SamplingStrategy from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message, ImageMedia +from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer from openai import OpenAI @@ -27,9 +27,7 @@ from llama_stack.providers.utils.inference.openai_compat import ( process_chat_completion_stream_response, ) -from llama_stack.providers.utils.inference.prompt_adapter import ( - convert_message_to_dict, -) +from llama_stack.providers.utils.inference.prompt_adapter import convert_message_to_dict from .config import SambaNovaImplConfig @@ -93,7 +91,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: raise NotImplementedError() - + async def chat_completion( self, model_id: str, @@ -125,7 +123,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): return self._stream_chat_completion(request_sambanova, client) else: return await self._nonstream_chat_completion(request_sambanova, client) - + async def _nonstream_chat_completion( self, request: ChatCompletionRequest, client: OpenAI ) -> ChatCompletionResponse: @@ -145,18 +143,22 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): stream, self.formatter ): yield chunk - + async def embeddings( self, model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() - - async def convert_chat_completion_request(self, request: ChatCompletionRequest) -> dict: + + async def convert_chat_completion_request( + self, request: ChatCompletionRequest + ) -> dict: compatible_request = self.convert_sampling_params(request.sampling_params) compatible_request["model"] = request.model - compatible_request["messages"] = await self.convert_to_sambanova_message(request.messages) + compatible_request["messages"] = await self.convert_to_sambanova_message( + request.messages + ) compatible_request["stream"] = request.stream compatible_request["logprobs"] = False compatible_request["extra_headers"] = { @@ -164,7 +166,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): } return compatible_request - def convert_sampling_params(self, sampling_params: SamplingParams, legacy: bool = False) -> dict: + def convert_sampling_params( + self, sampling_params: SamplingParams, legacy: bool = False + ) -> dict: params = {} if sampling_params: @@ -182,14 +186,14 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): params["extra_body"]["top_k"] = sampling_params.top_k elif sampling_params.strategy == "greedy": params["temperature"] = sampling_params.temperature - + return params async def convert_to_sambanova_message(self, messages: List[Message]) -> List[dict]: conversation = [] for message in messages: content = await convert_message_to_dict(message) - + # Need to override role if isinstance(message, UserMessage): content["role"] = "user" @@ -197,14 +201,16 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): content["role"] = "assistant" tools = [] for tool_call in message.tool_calls: - tools.append({ - "id": tool_call.call_id, - "function": { - "name": tool_call.name, - "arguments": json.dumps(tool_call.arguments), - }, - "type": "function", - }) + tools.append( + { + "id": tool_call.call_id, + "function": { + "name": tool_call.name, + "arguments": json.dumps(tool_call.arguments), + }, + "type": "function", + } + ) content["tool_calls"] = tools elif isinstance(message, ToolResponseMessage): content["role"] = "tool" @@ -215,5 +221,3 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): conversation.append(content) return conversation - - \ No newline at end of file