Merge branch 'main' into groq

Swan Htet Aung authored 2024-11-26 12:28:31 -06:00, committed by GitHub
commit bc427b3081
9 changed files with 43 additions and 8 deletions

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

@@ -3,12 +3,13 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from pydantic import BaseModel
+from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)
+from pydantic import BaseModel
 
 
 class HuggingfaceDatasetIOConfig(BaseModel):
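The hunk cuts off at the class line, but the new imports indicate the config gains a persistent, SQLite-backed key-value store rooted under RUNTIME_BASE_DIR. A plausible sketch of the post-change class, assuming a `kvstore` field with a SQLite default (the field name, the default, and the db filename are not shown in the truncated diff):

from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
    KVStoreConfig,
    SqliteKVStoreConfig,
)
from pydantic import BaseModel


class HuggingfaceDatasetIOConfig(BaseModel):
    # Assumed field: persists dataset registrations in a local SQLite file.
    kvstore: KVStoreConfig = SqliteKVStoreConfig(
        db_path=(RUNTIME_BASE_DIR / "huggingface_datasetio.db").as_posix()
    )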

@@ -9,6 +9,7 @@ from llama_stack.apis.datasetio import *  # noqa: F403
 import datasets as hf_datasets
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
+from llama_stack.providers.utils.kvstore import kvstore_impl
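`kvstore_impl` is the factory that turns a KVStoreConfig into a live store. A hedged sketch of how the datasetio provider would plausibly use it to persist dataset registrations; the class skeleton, method names, and key layout are illustrative assumptions, not code from this commit:

from llama_stack.providers.utils.kvstore import kvstore_impl


class HuggingfaceDatasetIOImpl:
    def __init__(self, config) -> None:
        # config is the provider's HuggingfaceDatasetIOConfig (assumed).
        self.config = config
        self.kvstore = None

    async def initialize(self) -> None:
        # Build the concrete store (SQLite by default) from the config.
        self.kvstore = await kvstore_impl(self.config.kvstore)

    async def register_dataset(self, dataset) -> None:
        # Persist the registration so it survives server restarts.
        # Key layout and serialization here are assumptions.
        await self.kvstore.set(
            key=f"datasets:{dataset.identifier}",
            value=dataset.model_dump_json(),
        )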

@@ -89,8 +89,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -194,8 +195,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
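The two hunks above apply the same fix: the caller passes a user-facing alias (`model_id`), while the serving backend expects the provider's own identifier, so the adapter now resolves the alias through the model store before building the request. A self-contained sketch of the pattern, with illustrative names and values:

from dataclasses import dataclass


@dataclass
class Model:
    identifier: str            # user-facing alias, e.g. "llama3.1-8b" (hypothetical)
    provider_resource_id: str  # backend-facing name, e.g. a Hugging Face repo id


class ModelStore:
    def __init__(self) -> None:
        self._models: dict[str, Model] = {}

    async def get_model(self, model_id: str) -> Model:
        return self._models[model_id]


async def resolve_backend_model(store: ModelStore, model_id: str) -> str:
    # Resolve the alias first, then address the backend by its own id,
    # mirroring `model.provider_resource_id` in the hunks above.
    model = await store.get_model(model_id)
    return model.provider_resource_id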
@@ -249,7 +251,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         prompt, input_tokens = chat_completion_request_to_model_input_info(
-            request, self.formatter
+            request, self.register_helper.get_llama_model(request.model), self.formatter
         )
         return dict(
             prompt=prompt,
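The `_get_params` change matters because `chat_completion_request_to_model_input_info` formats the prompt with model-specific templates and special tokens; since `request.model` now carries the provider resource id (see the hunks above), it is mapped back to the canonical Llama model name via `register_helper.get_llama_model` first. A toy sketch of the idea; the mapping table and helper body are hypothetical, not the library's actual internals:

# Hypothetical mapping from a provider resource id back to the canonical
# Llama model name the prompt formatter keys its templates on.
PROVIDER_TO_LLAMA = {
    "meta-llama/Llama-3.1-8B-Instruct": "Llama3.1-8B-Instruct",  # example values
}


def get_llama_model(provider_resource_id: str) -> str:
    # Fall back to the input if no mapping is registered.
    return PROVIDER_TO_LLAMA.get(provider_resource_id, provider_resource_id)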