fix nvidia inference provider (#781)
# What does this PR do?

- Fixes the NVIDIA inference provider to account for the strategy update.
- Updates the NVIDIA templates.

## Test Plan

```
llama stack run ./llama_stack/templates/nvidia/run.yaml --port 5000
LLAMA_STACK_BASE_URL="http://localhost:5000" pytest -v tests/client-sdk/inference/test_inference.py --html=report.html --self-contained-html
```

<img width="1288" alt="image" src="https://github.com/user-attachments/assets/d20f9aea-525e-47de-a5be-586e022e0d55" />

**NOTE** Still broken after this change:

- vision inference
- tool calling
- /completion

cc @mattf @cdgamarose-nv for improving the NVIDIA inference adapter.

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
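Beyond the pytest run in the Test Plan, a quick interactive smoke test against the running stack could look like the sketch below. This is illustrative only: the `llama_stack_client` import, method names, and model id are assumptions about the client SDK, not something taken from or verified by this PR.

```python
# Hypothetical smoke test for the NVIDIA provider (not part of this PR).
# Assumes the server from the Test Plan above is listening on port 5000
# and that the `llama_stack_client` package is installed; the model id
# is illustrative -- pick one reported by `client.models.list()`.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.completion_message.content)
```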
Parent: 965644ce68
Commit: b76bef169c

5 changed files with 351 additions and 262 deletions
```diff
@@ -6,8 +6,11 @@
 from pathlib import Path
 
+from llama_models.sku_list import all_registered_models
+
 from llama_stack.distribution.datatypes import ModelInput, Provider
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
+from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES
 
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
 
@@ -36,10 +39,17 @@ def get_distribution_template() -> DistributionTemplate:
         config=NVIDIAConfig.sample_run_config(),
     )
 
-    inference_model = ModelInput(
-        model_id="${env.INFERENCE_MODEL}",
-        provider_id="nvidia",
-    )
+    core_model_to_hf_repo = {
+        m.descriptor(): m.huggingface_repo for m in all_registered_models()
+    }
+    default_models = [
+        ModelInput(
+            model_id=core_model_to_hf_repo[m.llama_model],
+            provider_model_id=m.provider_model_id,
+            provider_id="nvidia",
+        )
+        for m in _MODEL_ALIASES
+    ]
 
     return DistributionTemplate(
         name="nvidia",
@@ -48,13 +58,13 @@ def get_distribution_template() -> DistributionTemplate:
         docker_image=None,
         template_path=Path(__file__).parent / "doc_template.md",
         providers=providers,
-        default_models=[inference_model],
+        default_models=default_models,
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [inference_provider],
                 },
-                default_models=[inference_model],
+                default_models=default_models,
             ),
         },
         run_config_env_vars={
```
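For context on the comprehension in the second hunk: each entry in `_MODEL_ALIASES` pairs the model id the NVIDIA endpoint expects (`provider_model_id`) with a core Llama model descriptor (`llama_model`), which the template resolves to a HuggingFace repo name via `all_registered_models()`. The sketch below shows that shape with a stand-in dataclass; the field names come from the diff itself, but the helper type and model ids are assumptions for exposition, not copied from the provider source.

```python
# Illustrative stand-in for the alias entries the template iterates over.
# The real list lives in the NVIDIA inference provider; ids below are examples.
from dataclasses import dataclass


@dataclass
class ModelAlias:
    provider_model_id: str  # id the NVIDIA endpoint serves, e.g. "meta/..."
    llama_model: str        # core Llama descriptor used for the HF-repo lookup


_MODEL_ALIASES = [
    ModelAlias(
        provider_model_id="meta/llama-3.1-8b-instruct",
        llama_model="Llama3.1-8B-Instruct",
    ),
    ModelAlias(
        provider_model_id="meta/llama-3.1-70b-instruct",
        llama_model="Llama3.1-70B-Instruct",
    ),
]

# The template builds one ModelInput per alias, mapping descriptor -> endpoint id.
for m in _MODEL_ALIASES:
    print(m.llama_model, "->", m.provider_model_id)
```

The net effect of the change is that the template registers every model the provider knows about as a default, instead of a single model taken from the `INFERENCE_MODEL` environment variable.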