Merge branch 'main' into fix/nvidia-launch-customization

This commit is contained in:
Jash Gulabrai 2025-04-25 16:01:22 -04:00
commit 6659ed995a
53 changed files with 2203 additions and 217 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,134 @@
# NVIDIA NeMo Evaluator Eval Provider
## Overview
For the first integration, Benchmarks are mapped to Evaluation Configs on in the NeMo Evaluator. The full evaluation config object is provided as part of the meta-data. The `dataset_id` and `scoring_functions` are not used.
Below are a few examples of how to register a benchmark, which in turn will create an evaluation config in NeMo Evaluator and how to trigger an evaluation.
### Example for register an academic benchmark
```
POST /eval/benchmarks
```
```json
{
"benchmark_id": "mmlu",
"dataset_id": "",
"scoring_functions": [],
"metadata": {
"type": "mmlu"
}
}
```
### Example for register a custom evaluation
```
POST /eval/benchmarks
```
```json
{
"benchmark_id": "my-custom-benchmark",
"dataset_id": "",
"scoring_functions": [],
"metadata": {
"type": "custom",
"params": {
"parallelism": 8
},
"tasks": {
"qa": {
"type": "completion",
"params": {
"template": {
"prompt": "{{prompt}}",
"max_tokens": 200
}
},
"dataset": {
"files_url": "hf://datasets/default/sample-basic-test/testing/testing.jsonl"
},
"metrics": {
"bleu": {
"type": "bleu",
"params": {
"references": [
"{{ideal_response}}"
]
}
}
}
}
}
}
}
```
### Example for triggering a benchmark/custom evaluation
```
POST /eval/benchmarks/{benchmark_id}/jobs
```
```json
{
"benchmark_id": "my-custom-benchmark",
"benchmark_config": {
"eval_candidate": {
"type": "model",
"model": "meta-llama/Llama3.1-8B-Instruct",
"sampling_params": {
"max_tokens": 100,
"temperature": 0.7
}
},
"scoring_params": {}
}
}
```
Response example:
```json
{
"job_id": "eval-1234",
"status": "in_progress"
}
```
### Example for getting the status of a job
```
GET /eval/benchmarks/{benchmark_id}/jobs/{job_id}
```
Response example:
```json
{
"job_id": "eval-1234",
"status": "in_progress"
}
```
### Example for cancelling a job
```
POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel
```
### Example for getting the results
```
GET /eval/benchmarks/{benchmark_id}/results
```
```json
{
"generations": [],
"scores": {
"{benchmark_id}": {
"score_rows": [],
"aggregated_results": {
"tasks": {},
"groups": {}
}
}
}
}
```

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api
from .config import NVIDIAEvalConfig
async def get_adapter_impl(
config: NVIDIAEvalConfig,
deps: Dict[Api, Any],
):
from .eval import NVIDIAEvalImpl
impl = NVIDIAEvalImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
deps[Api.scoring],
deps[Api.inference],
deps[Api.agents],
)
await impl.initialize()
return impl
__all__ = ["get_adapter_impl", "NVIDIAEvalImpl"]

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict
from pydantic import BaseModel, Field
class NVIDIAEvalConfig(BaseModel):
"""
Configuration for the NVIDIA NeMo Evaluator microservice endpoint.
Attributes:
evaluator_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000.
"""
evaluator_url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"),
description="The url for accessing the evaluator service",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
}

View file

@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
import requests
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring, ScoringResult
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import NVIDIAEvalConfig
DEFAULT_NAMESPACE = "nvidia"
class NVIDIAEvalImpl(
Eval,
BenchmarksProtocolPrivate,
ModelRegistryHelper,
):
def __init__(
self,
config: NVIDIAEvalConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
scoring_api: Scoring,
inference_api: Inference,
agents_api: Agents,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.scoring_api = scoring_api
self.inference_api = inference_api
self.agents_api = agents_api
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def _evaluator_get(self, path):
"""Helper for making GET requests to the evaluator service."""
response = requests.get(url=f"{self.config.evaluator_url}{path}")
response.raise_for_status()
return response.json()
async def _evaluator_post(self, path, data):
"""Helper for making POST requests to the evaluator service."""
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
response.raise_for_status()
return response.json()
async def register_benchmark(self, task_def: Benchmark) -> None:
"""Register a benchmark as an evaluation configuration."""
await self._evaluator_post(
"/v1/evaluation/configs",
{
"namespace": DEFAULT_NAMESPACE,
"name": task_def.benchmark_id,
# metadata is copied to request body as-is
**task_def.metadata,
},
)
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation job for a benchmark."""
model = (
benchmark_config.eval_candidate.model
if benchmark_config.eval_candidate.type == "model"
else benchmark_config.eval_candidate.config.model
)
nvidia_model = self.get_provider_model_id(model) or model
result = await self._evaluator_post(
"/v1/evaluation/jobs",
{
"config": f"{DEFAULT_NAMESPACE}/{benchmark_id}",
"target": {"type": "model", "model": nvidia_model},
},
)
return Job(job_id=result["id"], status=JobStatus.in_progress)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
raise NotImplementedError()
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of an evaluation job.
EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed".
JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed"
"""
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}")
result_status = result["status"]
job_status = JobStatus.failed
if result_status in ["created", "pending"]:
job_status = JobStatus.scheduled
elif result_status in ["running"]:
job_status = JobStatus.in_progress
elif result_status in ["completed"]:
job_status = JobStatus.completed
elif result_status in ["cancelled"]:
job_status = JobStatus.cancelled
return Job(job_id=job_id, status=job_status)
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel the evaluation job."""
await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {})
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Returns the results of the evaluation job."""
job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job {job_id} not completed. Status: {status.value}")
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results")
return EvaluateResponse(
# TODO: these are stored in detailed results on NeMo Evaluator side; can be added
generations=[],
scores={
benchmark_id: ScoringResult(
score_rows=[],
aggregated_results=result,
)
},
)

View file

@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel):
default=60,
description="Timeout for the HTTP requests",
)
append_api_version: bool = Field(
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
"api_key": "${env.NVIDIA_API_KEY:}",
"append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
}

View file

@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
@ -42,7 +41,11 @@ from llama_stack.apis.inference.inference import (
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -120,10 +123,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
}
base_url = f"{self._config.url}/v1"
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
base_url = special_model_urls[provider_model_id]
return _get_client_for_base_url(base_url)
async def _get_provider_model_id(self, model_id: str) -> str:
@ -387,3 +390,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
return await self._get_client(provider_model_id).chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
async def register_model(self, model: Model) -> Model:
"""
Allow non-llama model registration.
Non-llama model registration: API Catalogue models, post-training models, etc.
client = LlamaStackAsLibraryClient("nvidia")
client.models.register(
model_id="mistralai/mixtral-8x7b-instruct-v0.1",
model_type=ModelType.llm,
provider_id="nvidia",
provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
)
NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
"""
if model.model_type == ModelType.embedding:
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
else:
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
if provider_resource_id:
model.provider_resource_id = provider_resource_id
else:
llama_model = model.metadata.get("llama_model")
existing_llama_model = self.get_llama_model(model.provider_resource_id)
if existing_llama_model:
if existing_llama_model != llama_model:
raise ValueError(
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
)
else:
# not llama model
if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
)
else:
self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
return model

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import WatsonXConfig
async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
# import dynamically so `llama stack build` does not fail due to missing dependencies
from .watsonx import WatsonXInferenceAdapter
if not isinstance(config, WatsonXConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = WatsonXInferenceAdapter(config)
return adapter
__all__ = ["get_adapter_impl", "WatsonXConfig"]

View file

@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
class WatsonXProviderDataValidator(BaseModel):
url: str
api_key: str
project_id: str
@json_schema_type
class WatsonXConfig(BaseModel):
url: str = Field(
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai",
)
api_key: Optional[SecretStr] = Field(
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
description="The watsonx API key, only needed of using the hosted service",
)
project_id: Optional[str] = Field(
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
description="The Project ID key, only needed of using the hosted service",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
"api_key": "${env.WATSONX_API_KEY:}",
"project_id": "${env.WATSONX_PROJECT_ID:}",
}

View file

@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/llama-3-3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-2-13b-chat",
CoreModelId.llama2_13b.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]

View file

@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional, Union
from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
request_has_media,
)
from . import WatsonXConfig
from .models import MODEL_ENTRIES
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: WatsonXConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
print(f"Initializing watsonx InferenceAdapter({config.url})...")
self._config = config
self._project_id = self._config.project_id
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
def _get_client(self, model_id) -> Model:
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
config_url = self._config.url
project_id = self._config.project_id
credentials = {"url": config_url, "apikey": config_api_key}
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client(request.model).generate(**params)
choices = []
if "results" in r:
for result in r["results"]:
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
text=result["generated_text"],
)
choices.append(choice)
response = OpenAICompatCompletionResponse(
choices=choices,
)
return process_completion_response(response)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = self._get_client(request.model).generate_text_stream(**params)
for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=None,
text=chunk,
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_completion_stream_response(stream):
yield chunk
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client(request.model).generate(**params)
choices = []
if "results" in r:
for result in r["results"]:
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
text=result["generated_text"],
)
choices.append(choice)
response = OpenAICompatCompletionResponse(
choices=choices,
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
model_id = request.model
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
s = self._get_client(model_id).generate_text_stream(**params)
for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=None,
text=chunk,
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
input_dict = {"params": {}}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
else:
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
if request.sampling_params:
if request.sampling_params.strategy:
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
if request.sampling_params.max_tokens:
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
if request.sampling_params.repetition_penalty:
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
if request.sampling_params.additional_params.get("top_p"):
input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"]
if request.sampling_params.additional_params.get("top_k"):
input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"]
if request.sampling_params.additional_params.get("temperature"):
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
if request.sampling_params.additional_params.get("length_penalty"):
input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[
"length_penalty"
]
if request.sampling_params.additional_params.get("random_seed"):
input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"]
if request.sampling_params.additional_params.get("min_new_tokens"):
input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[
"min_new_tokens"
]
if request.sampling_params.additional_params.get("stop_sequences"):
input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[
"stop_sequences"
]
if request.sampling_params.additional_params.get("time_limit"):
input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"]
if request.sampling_params.additional_params.get("truncate_input_tokens"):
input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[
"truncate_input_tokens"
]
if request.sampling_params.additional_params.get("return_options"):
input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[
"return_options"
]
params = {
**input_dict,
}
return params
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
pass

View file

@ -36,7 +36,6 @@ import os
os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
@ -125,6 +124,21 @@ client.post_training.job.cancel(job_uuid="your-job-id")
### Inference with the fine-tuned model
#### 1. Register the model
```python
from llama_stack.apis.models import Model, ModelType
client.models.register(
model_id="test-example-model@v1",
provider_id="nvidia",
provider_model_id="test-example-model@v1",
model_type=ModelType.llm,
)
```
#### 2. Inference with the fine-tuned model
```python
response = client.inference.completion(
content="Complete the sentence using one word: Roses are red, violets are ",