fix linting error

This commit is contained in:
Ishaan Jaff 2024-07-16 21:21:50 -07:00
parent f02f3a7713
commit b04d20d367

View file

@ -1,14 +1,19 @@
import os, types
import copy
import json
from enum import Enum
import requests, copy # type: ignore
import os
import time
from typing import Callable, Optional, List
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import types
from enum import Enum
from typing import Callable, List, Optional
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
class TritonError(Exception):
@ -126,8 +131,12 @@ class TritonChatCompletion(BaseLLM):
)
@staticmethod
def split_embedding_by_shape(data: list[float], shape: list[int]) -> list[list[float]]:
def split_embedding_by_shape(
data: List[float], shape: List[int]
) -> List[List[float]]:
if len(shape) != 2:
raise ValueError("Shape must be of length 2.")
embedding_size = shape[1]
return [data[i * embedding_size: (i + 1) * embedding_size] for i in range(shape[0])]
return [
data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
]