mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fix: fix precommit (#1594)
# What does this PR do? - fix precommit [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan CI [//]: # (## Documentation)
This commit is contained in:
parent
90ca4d94de
commit
c7139b0b67
2 changed files with 22 additions and 13 deletions
|
@ -4,8 +4,9 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.models.models import ModelType
|
from llama_stack.apis.models.models import ModelType
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
BenchmarkInput,
|
BenchmarkInput,
|
||||||
|
@ -15,21 +16,27 @@ from llama_stack.distribution.datatypes import (
|
||||||
ShieldInput,
|
ShieldInput,
|
||||||
ToolGroupInput,
|
ToolGroupInput,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
|
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
||||||
|
SQLiteVectorIOConfig,
|
||||||
|
)
|
||||||
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
|
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
|
||||||
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
|
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
|
||||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||||
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
||||||
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
|
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
|
||||||
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
PGVectorVectorIOConfig,
|
||||||
ProviderModelEntry,
|
)
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||||
|
from llama_stack.templates.template import (
|
||||||
|
DistributionTemplate,
|
||||||
|
RunConfigSettings,
|
||||||
|
get_model_registry,
|
||||||
)
|
)
|
||||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
|
||||||
|
|
||||||
|
|
||||||
def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
|
def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
|
||||||
# in this template, we allow each API key to be optional
|
# in this template, we allow each API key to be optional
|
||||||
providers = [
|
providers = [
|
||||||
(
|
(
|
||||||
|
@ -164,7 +171,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
DatasetInput(
|
DatasetInput(
|
||||||
dataset_id="simpleqa",
|
dataset_id="simpleqa",
|
||||||
provider_id="huggingface",
|
provider_id="huggingface",
|
||||||
url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"},
|
url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"),
|
||||||
metadata={
|
metadata={
|
||||||
"path": "llamastack/simpleqa",
|
"path": "llamastack/simpleqa",
|
||||||
"split": "train",
|
"split": "train",
|
||||||
|
@ -178,7 +185,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
DatasetInput(
|
DatasetInput(
|
||||||
dataset_id="mmlu_cot",
|
dataset_id="mmlu_cot",
|
||||||
provider_id="huggingface",
|
provider_id="huggingface",
|
||||||
url={"uri": "https://huggingface.co/datasets/llamastack/mmlu_cot"},
|
url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"),
|
||||||
metadata={
|
metadata={
|
||||||
"path": "llamastack/mmlu_cot",
|
"path": "llamastack/mmlu_cot",
|
||||||
"name": "all",
|
"name": "all",
|
||||||
|
@ -193,7 +200,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
DatasetInput(
|
DatasetInput(
|
||||||
dataset_id="gpqa_cot",
|
dataset_id="gpqa_cot",
|
||||||
provider_id="huggingface",
|
provider_id="huggingface",
|
||||||
url={"uri": "https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"},
|
url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"),
|
||||||
metadata={
|
metadata={
|
||||||
"path": "llamastack/gpqa_0shot_cot",
|
"path": "llamastack/gpqa_0shot_cot",
|
||||||
"name": "gpqa_main",
|
"name": "gpqa_main",
|
||||||
|
@ -208,7 +215,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
DatasetInput(
|
DatasetInput(
|
||||||
dataset_id="math_500",
|
dataset_id="math_500",
|
||||||
provider_id="huggingface",
|
provider_id="huggingface",
|
||||||
url={"uri": "https://huggingface.co/datasets/llamastack/math_500"},
|
url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"),
|
||||||
metadata={
|
metadata={
|
||||||
"path": "llamastack/math_500",
|
"path": "llamastack/math_500",
|
||||||
"split": "test",
|
"split": "test",
|
||||||
|
|
|
@ -30,7 +30,9 @@ from llama_stack.providers.utils.inference.model_registry import ProviderModelEn
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
def get_model_registry(available_models: Dict[str, List[ProviderModelEntry]]) -> List[ModelInput]:
|
def get_model_registry(
|
||||||
|
available_models: Dict[str, List[ProviderModelEntry]],
|
||||||
|
) -> List[ModelInput]:
|
||||||
models = []
|
models = []
|
||||||
for provider_id, entries in available_models.items():
|
for provider_id, entries in available_models.items():
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
|
@ -193,7 +195,7 @@ class DistributionTemplate(BaseModel):
|
||||||
default_models.append(
|
default_models.append(
|
||||||
DefaultModel(
|
DefaultModel(
|
||||||
model_id=model_entry.provider_model_id,
|
model_id=model_entry.provider_model_id,
|
||||||
doc_string=f"({' -- '.join(doc_parts)})" if doc_parts else "",
|
doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue