Fix incorrect completion() signature for Databricks provider (#236)

This commit is contained in:
Yuan Tang 2024-10-11 11:47:57 -04:00 committed by GitHub
parent 9fbe8852aa
commit 2128e61da2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 19 additions and 6 deletions

View file

@ -7,6 +7,7 @@
from .config import DatabricksImplConfig from .config import DatabricksImplConfig
from .databricks import DatabricksInferenceAdapter from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps): async def get_adapter_impl(config: DatabricksImplConfig, _deps):
assert isinstance( assert isinstance(
config, DatabricksImplConfig config, DatabricksImplConfig

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field

View file

@ -48,7 +48,14 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
def completion(self, request: CompletionRequest) -> AsyncGenerator: def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError() raise NotImplementedError()
def chat_completion( def chat_completion(

View file

@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any from typing import Any
from .config import VLLMConfig from .config import VLLMConfig