formatting

This commit is contained in:
Dalton Flanagan 2024-08-14 17:03:43 -04:00
parent 069d877210
commit b311dcd143
22 changed files with 82 additions and 128 deletions

View file

@ -101,26 +101,22 @@ class Inference(Protocol):
async def completion(
self,
request: CompletionRequest,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
...
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@webmethod(route="/inference/chat_completion")
async def chat_completion(
self,
request: ChatCompletionRequest,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
...
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
@webmethod(route="/inference/batch_completion")
async def batch_completion(
self,
request: BatchCompletionRequest,
) -> BatchCompletionResponse:
...
) -> BatchCompletionResponse: ...
@webmethod(route="/inference/batch_chat_completion")
async def batch_chat_completion(
self,
request: BatchChatCompletionRequest,
) -> BatchChatCompletionResponse:
...
) -> BatchChatCompletionResponse: ...

View file

@ -4,12 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from llama_toolchain.inference.api import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
)
from termcolor import cprint
class LogEvent:

View file

@ -11,10 +11,10 @@ from llama_models.datatypes import ModelFamily
from llama_models.schema_utils import json_schema_type
from llama_models.sku_list import all_registered_models
from pydantic import BaseModel, Field, field_validator
from llama_toolchain.inference.api import QuantizationConfig
from pydantic import BaseModel, Field, field_validator
@json_schema_type
class MetaReferenceImplConfig(BaseModel):

View file

@ -28,10 +28,10 @@ from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer
from llama_models.llama3_1.reference_impl.model import Transformer
from llama_models.sku_list import resolve_model
from termcolor import cprint
from llama_toolchain.common.model_utils import model_local_dir
from llama_toolchain.inference.api import QuantizationType
from termcolor import cprint
from .config import MetaReferenceImplConfig