Mirror of https://github.com/meta-llama/llama-stack.git

Commit c29e3271d3: Merge branch 'main' into henrytu/cerebras-integration
38 changed files with 523 additions and 139 deletions
llama_stack/providers/remote/datasetio/__init__.py (new file, 5 additions)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
@@ -3,12 +3,13 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from pydantic import BaseModel
 
+from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
 from llama_stack.providers.utils.kvstore.config import (
     KVStoreConfig,
     SqliteKVStoreConfig,
 )
-from pydantic import BaseModel
 
 
 class HuggingfaceDatasetIOConfig(BaseModel):
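For orientation, the imports above line up with a config class roughly like the sketch below. This is an illustration only: the class body is truncated in the hunk, so the `kvstore` field, its default, and the database filename are assumptions, not content from this diff.

from pydantic import BaseModel

from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
    KVStoreConfig,
    SqliteKVStoreConfig,
)


class HuggingfaceDatasetIOConfig(BaseModel):
    # Assumed field: dataset registrations are persisted in a kvstore,
    # defaulting to a SQLite file under the runtime base directory.
    kvstore: KVStoreConfig = SqliteKVStoreConfig(
        db_path=(RUNTIME_BASE_DIR / "huggingface_datasetio.db").as_posix()
    )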
@@ -9,6 +9,7 @@ from llama_stack.apis.datasetio import *  # noqa: F403
 
 
 import datasets as hf_datasets
 
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
+from llama_stack.providers.utils.kvstore import kvstore_impl
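The new kvstore_impl import suggests the adapter now persists registered datasets rather than holding them only in memory. A hedged sketch of that pattern follows, assuming the store exposes an async set() and that the config carries a kvstore field; neither detail is shown in this hunk.

from llama_stack.providers.utils.kvstore import kvstore_impl

DATASETS_PREFIX = "datasets:"  # hypothetical key namespace for this sketch


class HuggingfaceDatasetIOImpl:
    """Sketch of kvstore-backed persistence; not the actual adapter code."""

    def __init__(self, config) -> None:
        self.config = config
        self.kvstore = None

    async def initialize(self) -> None:
        # Build the concrete store (e.g. SQLite) from the validated config;
        # `config.kvstore` is an assumed field name.
        self.kvstore = await kvstore_impl(self.config.kvstore)

    async def register_dataset(self, dataset_def) -> None:
        # Persisting the registration lets it survive server restarts.
        # (model_dump_json is pydantic v2; a v1 codebase would use .json().)
        await self.kvstore.set(
            key=f"{DATASETS_PREFIX}{dataset_def.identifier}",
            value=dataset_def.model_dump_json(),
        )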
@@ -35,7 +35,9 @@ class NVIDIAConfig(BaseModel):
     """
 
     url: str = Field(
-        default="https://integrate.api.nvidia.com",
+        default_factory=lambda: os.getenv(
+            "NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"
+        ),
         description="A base url for accessing the NVIDIA NIM",
     )
     api_key: Optional[str] = Field(
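The NVIDIA change swaps a hard-coded default for an environment-driven one. A minimal, self-contained illustration of the pattern (ExampleConfig is a stand-in, not the real class):

import os

from pydantic import BaseModel, Field


class ExampleConfig(BaseModel):
    # default_factory is evaluated when the model is instantiated, so an
    # environment variable set before construction takes effect; a plain
    # `default=` string would be fixed once at class-definition time.
    url: str = Field(
        default_factory=lambda: os.getenv(
            "NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"
        )
    )


os.environ["NVIDIA_BASE_URL"] = "http://localhost:8000"
assert ExampleConfig().url == "http://localhost:8000"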
@@ -89,8 +89,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -194,8 +195,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
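This hunk and the completion hunk above apply the same step: the user-facing model_id is resolved through the model store before the request is built, so the provider receives its own native model name (provider_resource_id) rather than the stack-level alias. A minimal stand-in illustrating the lookup (not the real ModelStore interface):

from dataclasses import dataclass


@dataclass
class Model:
    identifier: str            # stack-level alias clients pass as model_id
    provider_resource_id: str  # name the serving backend actually understands


class InMemoryModelStore:
    def __init__(self, models: list[Model]) -> None:
        self._by_id = {m.identifier: m for m in models}

    async def get_model(self, model_id: str) -> Model:
        return self._by_id[model_id]

Resolving through the registry this way decouples the names clients use from whatever the backend deployment happens to call its models.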
@@ -249,7 +251,7 @@
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         prompt, input_tokens = chat_completion_request_to_model_input_info(
-            request, self.formatter
+            request, self.register_helper.get_llama_model(request.model), self.formatter
         )
         return dict(
             prompt=prompt,
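The _get_params change threads a canonical Llama model identifier into prompt formatting, which matters because chat templates and special tokens differ by model family. A rough sketch of the idea behind register_helper.get_llama_model (the mapping entries are illustrative, not from this diff):

# Illustrative only: a register helper resolves the provider's model name back
# to a canonical Llama model so the formatter can choose the right template.
PROVIDER_TO_LLAMA = {
    "meta-llama/Llama-3.1-8B-Instruct": "Llama3.1-8B-Instruct",  # example entry
}


def get_llama_model(provider_model_id: str) -> str:
    return PROVIDER_TO_LLAMA.get(provider_model_id, provider_model_id)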