mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-29 17:49:24 +00:00
Merge branch 'main' into test-modelregistryhelper
This commit is contained in:
commit
7fd8a61b4d
80 changed files with 2918 additions and 386 deletions
5
llama_stack/providers/remote/eval/__init__.py
Normal file
5
llama_stack/providers/remote/eval/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
134
llama_stack/providers/remote/eval/nvidia/README.md
Normal file
134
llama_stack/providers/remote/eval/nvidia/README.md
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
# NVIDIA NeMo Evaluator Eval Provider
|
||||
|
||||
|
||||
## Overview
|
||||
|
||||
For the first integration, Benchmarks are mapped to Evaluation Configs on in the NeMo Evaluator. The full evaluation config object is provided as part of the meta-data. The `dataset_id` and `scoring_functions` are not used.
|
||||
|
||||
Below are a few examples of how to register a benchmark, which in turn will create an evaluation config in NeMo Evaluator and how to trigger an evaluation.
|
||||
|
||||
### Example for register an academic benchmark
|
||||
|
||||
```
|
||||
POST /eval/benchmarks
|
||||
```
|
||||
```json
|
||||
{
|
||||
"benchmark_id": "mmlu",
|
||||
"dataset_id": "",
|
||||
"scoring_functions": [],
|
||||
"metadata": {
|
||||
"type": "mmlu"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Example for register a custom evaluation
|
||||
|
||||
```
|
||||
POST /eval/benchmarks
|
||||
```
|
||||
```json
|
||||
{
|
||||
"benchmark_id": "my-custom-benchmark",
|
||||
"dataset_id": "",
|
||||
"scoring_functions": [],
|
||||
"metadata": {
|
||||
"type": "custom",
|
||||
"params": {
|
||||
"parallelism": 8
|
||||
},
|
||||
"tasks": {
|
||||
"qa": {
|
||||
"type": "completion",
|
||||
"params": {
|
||||
"template": {
|
||||
"prompt": "{{prompt}}",
|
||||
"max_tokens": 200
|
||||
}
|
||||
},
|
||||
"dataset": {
|
||||
"files_url": "hf://datasets/default/sample-basic-test/testing/testing.jsonl"
|
||||
},
|
||||
"metrics": {
|
||||
"bleu": {
|
||||
"type": "bleu",
|
||||
"params": {
|
||||
"references": [
|
||||
"{{ideal_response}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Example for triggering a benchmark/custom evaluation
|
||||
|
||||
```
|
||||
POST /eval/benchmarks/{benchmark_id}/jobs
|
||||
```
|
||||
```json
|
||||
{
|
||||
"benchmark_id": "my-custom-benchmark",
|
||||
"benchmark_config": {
|
||||
"eval_candidate": {
|
||||
"type": "model",
|
||||
"model": "meta-llama/Llama3.1-8B-Instruct",
|
||||
"sampling_params": {
|
||||
"max_tokens": 100,
|
||||
"temperature": 0.7
|
||||
}
|
||||
},
|
||||
"scoring_params": {}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Response example:
|
||||
```json
|
||||
{
|
||||
"job_id": "eval-1234",
|
||||
"status": "in_progress"
|
||||
}
|
||||
```
|
||||
|
||||
### Example for getting the status of a job
|
||||
```
|
||||
GET /eval/benchmarks/{benchmark_id}/jobs/{job_id}
|
||||
```
|
||||
|
||||
Response example:
|
||||
```json
|
||||
{
|
||||
"job_id": "eval-1234",
|
||||
"status": "in_progress"
|
||||
}
|
||||
```
|
||||
|
||||
### Example for cancelling a job
|
||||
```
|
||||
POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel
|
||||
```
|
||||
|
||||
### Example for getting the results
|
||||
```
|
||||
GET /eval/benchmarks/{benchmark_id}/results
|
||||
```
|
||||
```json
|
||||
{
|
||||
"generations": [],
|
||||
"scores": {
|
||||
"{benchmark_id}": {
|
||||
"score_rows": [],
|
||||
"aggregated_results": {
|
||||
"tasks": {},
|
||||
"groups": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
31
llama_stack/providers/remote/eval/nvidia/__init__.py
Normal file
31
llama_stack/providers/remote/eval/nvidia/__init__.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import NVIDIAEvalConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: NVIDIAEvalConfig,
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .eval import NVIDIAEvalImpl
|
||||
|
||||
impl = NVIDIAEvalImpl(
|
||||
config,
|
||||
deps[Api.datasetio],
|
||||
deps[Api.datasets],
|
||||
deps[Api.scoring],
|
||||
deps[Api.inference],
|
||||
deps[Api.agents],
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
__all__ = ["get_adapter_impl", "NVIDIAEvalImpl"]
|
||||
29
llama_stack/providers/remote/eval/nvidia/config.py
Normal file
29
llama_stack/providers/remote/eval/nvidia/config.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class NVIDIAEvalConfig(BaseModel):
|
||||
"""
|
||||
Configuration for the NVIDIA NeMo Evaluator microservice endpoint.
|
||||
|
||||
Attributes:
|
||||
evaluator_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000.
|
||||
"""
|
||||
|
||||
evaluator_url: str = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"),
|
||||
description="The url for accessing the evaluator service",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
|
||||
}
|
||||
154
llama_stack/providers/remote/eval/nvidia/eval.py
Normal file
154
llama_stack/providers/remote/eval/nvidia/eval.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.scoring import Scoring, ScoringResult
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from .....apis.common.job_types import Job, JobStatus
|
||||
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
|
||||
from .config import NVIDIAEvalConfig
|
||||
|
||||
DEFAULT_NAMESPACE = "nvidia"
|
||||
|
||||
|
||||
class NVIDIAEvalImpl(
|
||||
Eval,
|
||||
BenchmarksProtocolPrivate,
|
||||
ModelRegistryHelper,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
config: NVIDIAEvalConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
scoring_api: Scoring,
|
||||
inference_api: Inference,
|
||||
agents_api: Agents,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
self.scoring_api = scoring_api
|
||||
self.inference_api = inference_api
|
||||
self.agents_api = agents_api
|
||||
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def _evaluator_get(self, path):
|
||||
"""Helper for making GET requests to the evaluator service."""
|
||||
response = requests.get(url=f"{self.config.evaluator_url}{path}")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _evaluator_post(self, path, data):
|
||||
"""Helper for making POST requests to the evaluator service."""
|
||||
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def register_benchmark(self, task_def: Benchmark) -> None:
|
||||
"""Register a benchmark as an evaluation configuration."""
|
||||
await self._evaluator_post(
|
||||
"/v1/evaluation/configs",
|
||||
{
|
||||
"namespace": DEFAULT_NAMESPACE,
|
||||
"name": task_def.benchmark_id,
|
||||
# metadata is copied to request body as-is
|
||||
**task_def.metadata,
|
||||
},
|
||||
)
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
"""Run an evaluation job for a benchmark."""
|
||||
model = (
|
||||
benchmark_config.eval_candidate.model
|
||||
if benchmark_config.eval_candidate.type == "model"
|
||||
else benchmark_config.eval_candidate.config.model
|
||||
)
|
||||
nvidia_model = self.get_provider_model_id(model) or model
|
||||
|
||||
result = await self._evaluator_post(
|
||||
"/v1/evaluation/jobs",
|
||||
{
|
||||
"config": f"{DEFAULT_NAMESPACE}/{benchmark_id}",
|
||||
"target": {"type": "model", "model": nvidia_model},
|
||||
},
|
||||
)
|
||||
|
||||
return Job(job_id=result["id"], status=JobStatus.in_progress)
|
||||
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
"""Get the status of an evaluation job.
|
||||
|
||||
EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed".
|
||||
JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed"
|
||||
"""
|
||||
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}")
|
||||
result_status = result["status"]
|
||||
|
||||
job_status = JobStatus.failed
|
||||
if result_status in ["created", "pending"]:
|
||||
job_status = JobStatus.scheduled
|
||||
elif result_status in ["running"]:
|
||||
job_status = JobStatus.in_progress
|
||||
elif result_status in ["completed"]:
|
||||
job_status = JobStatus.completed
|
||||
elif result_status in ["cancelled"]:
|
||||
job_status = JobStatus.cancelled
|
||||
|
||||
return Job(job_id=job_id, status=job_status)
|
||||
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
"""Cancel the evaluation job."""
|
||||
await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {})
|
||||
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
"""Returns the results of the evaluation job."""
|
||||
|
||||
job = await self.job_status(benchmark_id, job_id)
|
||||
status = job.status
|
||||
if not status or status != JobStatus.completed:
|
||||
raise ValueError(f"Job {job_id} not completed. Status: {status.value}")
|
||||
|
||||
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results")
|
||||
|
||||
return EvaluateResponse(
|
||||
# TODO: these are stored in detailed results on NeMo Evaluator side; can be added
|
||||
generations=[],
|
||||
scores={
|
||||
benchmark_id: ScoringResult(
|
||||
score_rows=[],
|
||||
aggregated_results=result,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
|
@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel):
|
|||
default=60,
|
||||
description="Timeout for the HTTP requests",
|
||||
)
|
||||
append_api_version: bool = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
||||
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
|
||||
"api_key": "${env.NVIDIA_API_KEY:}",
|
||||
"append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
|
|||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
|
|
@ -42,7 +41,11 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import ToolPromptFormat
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference import (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
|
@ -120,10 +123,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
|
||||
}
|
||||
|
||||
base_url = f"{self._config.url}/v1"
|
||||
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||
|
||||
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
|
||||
base_url = special_model_urls[provider_model_id]
|
||||
|
||||
return _get_client_for_base_url(base_url)
|
||||
|
||||
async def _get_provider_model_id(self, model_id: str) -> str:
|
||||
|
|
@ -387,3 +390,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
return await self._get_client(provider_model_id).chat.completions.create(**params)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
"""
|
||||
Allow non-llama model registration.
|
||||
|
||||
Non-llama model registration: API Catalogue models, post-training models, etc.
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
client.models.register(
|
||||
model_id="mistralai/mixtral-8x7b-instruct-v0.1",
|
||||
model_type=ModelType.llm,
|
||||
provider_id="nvidia",
|
||||
provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
|
||||
)
|
||||
|
||||
NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
|
||||
"""
|
||||
if model.model_type == ModelType.embedding:
|
||||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||
provider_resource_id = model.provider_resource_id
|
||||
else:
|
||||
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
|
||||
if provider_resource_id:
|
||||
model.provider_resource_id = provider_resource_id
|
||||
else:
|
||||
llama_model = model.metadata.get("llama_model")
|
||||
existing_llama_model = self.get_llama_model(model.provider_resource_id)
|
||||
if existing_llama_model:
|
||||
if existing_llama_model != llama_model:
|
||||
raise ValueError(
|
||||
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
|
||||
)
|
||||
else:
|
||||
# not llama model
|
||||
if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
|
||||
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
||||
)
|
||||
else:
|
||||
self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from ollama import AsyncClient
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
|
@ -73,6 +72,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
from ollama import AsyncClient # type: ignore[attr-defined]
|
||||
|
||||
from .models import model_entries
|
||||
|
||||
|
|
|
|||
|
|
@ -76,8 +76,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
# Together client has no close method, so just set to None
|
||||
self._client = None
|
||||
if self._openai_client:
|
||||
await self._openai_client.close()
|
||||
self._openai_client = None
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
|
@ -359,7 +362,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if params.get("stream", True):
|
||||
if params.get("stream", False):
|
||||
return self._stream_openai_chat_completion(params)
|
||||
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||
|
||||
|
|
|
|||
|
|
@ -231,12 +231,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self.client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing VLLM client with base_url={self.config.url}")
|
||||
self.client = AsyncOpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token,
|
||||
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
|
||||
)
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
|
@ -249,6 +244,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
raise ValueError("Model store not set")
|
||||
return await self.model_store.get_model(model_id)
|
||||
|
||||
def _lazy_initialize_client(self):
|
||||
if self.client is not None:
|
||||
return
|
||||
|
||||
log.info(f"Initializing vLLM client with base_url={self.config.url}")
|
||||
self.client = self._create_client()
|
||||
|
||||
def _create_client(self):
|
||||
return AsyncOpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token,
|
||||
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
|
||||
)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
@ -258,6 +267,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
self._lazy_initialize_client()
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
|
|
@ -287,6 +297,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
self._lazy_initialize_client()
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
|
|
@ -357,12 +368,15 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
assert self.client is not None
|
||||
# register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
|
||||
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
|
||||
# Changing this may lead to unpredictable behavior.
|
||||
client = self._create_client() if self.client is None else self.client
|
||||
try:
|
||||
model = await self.register_helper.register_model(model)
|
||||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
res = await self.client.models.list()
|
||||
res = await client.models.list()
|
||||
available_models = [m.id async for m in res]
|
||||
if model.provider_resource_id not in available_models:
|
||||
raise ValueError(
|
||||
|
|
@ -413,6 +427,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None
|
||||
model = await self._get_model(model_id)
|
||||
|
||||
|
|
@ -452,6 +467,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
self._lazy_initialize_client()
|
||||
model_obj = await self._get_model(model)
|
||||
|
||||
extra_body: Dict[str, Any] = {}
|
||||
|
|
@ -508,6 +524,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
self._lazy_initialize_client()
|
||||
model_obj = await self._get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
|
|
|
|||
22
llama_stack/providers/remote/inference/watsonx/__init__.py
Normal file
22
llama_stack/providers/remote/inference/watsonx/__init__.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import WatsonXConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
|
||||
# import dynamically so `llama stack build` does not fail due to missing dependencies
|
||||
from .watsonx import WatsonXInferenceAdapter
|
||||
|
||||
if not isinstance(config, WatsonXConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
adapter = WatsonXInferenceAdapter(config)
|
||||
return adapter
|
||||
|
||||
|
||||
__all__ = ["get_adapter_impl", "WatsonXConfig"]
|
||||
46
llama_stack/providers/remote/inference/watsonx/config.py
Normal file
46
llama_stack/providers/remote/inference/watsonx/config.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class WatsonXProviderDataValidator(BaseModel):
|
||||
url: str
|
||||
api_key: str
|
||||
project_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class WatsonXConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||
description="A base url for accessing the watsonx.ai",
|
||||
)
|
||||
api_key: Optional[SecretStr] = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
|
||||
description="The watsonx API key, only needed of using the hosted service",
|
||||
)
|
||||
project_id: Optional[str] = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
|
||||
description="The Project ID key, only needed of using the hosted service",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=60,
|
||||
description="Timeout for the HTTP requests",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
|
||||
"api_key": "${env.WATSONX_API_KEY:}",
|
||||
"project_id": "${env.WATSONX_PROJECT_ID:}",
|
||||
}
|
||||
47
llama_stack/providers/remote/inference/watsonx/models.py
Normal file
47
llama_stack/providers/remote/inference/watsonx/models.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-3-70b-instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-2-13b-chat",
|
||||
CoreModelId.llama2_13b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-1-70b-instruct",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-1-8b-instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-11b-vision-instruct",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-1b-instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-3b-instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-3-2-90b-vision-instruct",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/llama-guard-3-11b-vision",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
]
|
||||
378
llama_stack/providers/remote/inference/watsonx/watsonx.py
Normal file
378
llama_stack/providers/remote/inference/watsonx/watsonx.py
Normal file
|
|
@ -0,0 +1,378 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from ibm_watson_machine_learning.foundation_models import Model
|
||||
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
GreedySamplingStrategy,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
prepare_openai_completion_params,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from . import WatsonXConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
||||
def __init__(self, config: WatsonXConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
|
||||
print(f"Initializing watsonx InferenceAdapter({config.url})...")
|
||||
|
||||
self._config = config
|
||||
|
||||
self._project_id = self._config.project_id
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
def _get_client(self, model_id) -> Model:
|
||||
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
|
||||
config_url = self._config.url
|
||||
project_id = self._config.project_id
|
||||
credentials = {"url": config_url, "apikey": config_api_key}
|
||||
|
||||
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
if not self._openai_client:
|
||||
self._openai_client = AsyncOpenAI(
|
||||
base_url=f"{self._config.url}/openai/v1",
|
||||
api_key=self._config.api_key,
|
||||
)
|
||||
return self._openai_client
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self._get_client(request.model).generate(**params)
|
||||
choices = []
|
||||
if "results" in r:
|
||||
for result in r["results"]:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
|
||||
text=result["generated_text"],
|
||||
)
|
||||
choices.append(choice)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=choices,
|
||||
)
|
||||
return process_completion_response(response)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = self._get_client(request.model).generate_text_stream(**params)
|
||||
for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=None,
|
||||
text=chunk,
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self._get_client(request.model).generate(**params)
|
||||
choices = []
|
||||
if "results" in r:
|
||||
for result in r["results"]:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
|
||||
text=result["generated_text"],
|
||||
)
|
||||
choices.append(choice)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=choices,
|
||||
)
|
||||
return process_chat_completion_response(response, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
model_id = request.model
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
s = self._get_client(model_id).generate_text_stream(**params)
|
||||
for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=None,
|
||||
text=chunk,
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
|
||||
input_dict = {"params": {}}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||
else:
|
||||
assert not media_present, "Together does not support media for Completion requests"
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
if request.sampling_params:
|
||||
if request.sampling_params.strategy:
|
||||
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
|
||||
if request.sampling_params.max_tokens:
|
||||
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
|
||||
if request.sampling_params.repetition_penalty:
|
||||
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
|
||||
|
||||
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
|
||||
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
|
||||
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
|
||||
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
|
||||
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
|
||||
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
|
||||
input_dict["params"][GenParams.TEMPERATURE] = 0.0
|
||||
|
||||
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
|
||||
|
||||
params = {
|
||||
**input_dict,
|
||||
}
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError("embedding is not supported for watsonx")
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: Union[str, List[str], List[int], List[List[int]]],
|
||||
best_of: Optional[int] = None,
|
||||
echo: Optional[bool] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self._get_openai_client().completions.create(**params) # type: ignore
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[OpenAIMessageParam],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if params.get("stream", False):
|
||||
return self._stream_openai_chat_completion(params)
|
||||
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||
|
||||
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
|
||||
# watsonx.ai sometimes adds usage data to the stream
|
||||
include_usage = False
|
||||
if params.get("stream_options", None):
|
||||
include_usage = params["stream_options"].get("include_usage", False)
|
||||
stream = await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
seen_finish_reason = False
|
||||
async for chunk in stream:
|
||||
# Final usage chunk with no choices that the user didn't request, so discard
|
||||
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
|
||||
break
|
||||
yield chunk
|
||||
for choice in chunk.choices:
|
||||
if choice.finish_reason:
|
||||
seen_finish_reason = True
|
||||
break
|
||||
|
|
@ -36,7 +36,6 @@ import os
|
|||
|
||||
os.environ["NVIDIA_API_KEY"] = "your-api-key"
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||
os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
|
||||
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
|
||||
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
|
||||
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
|
||||
|
|
@ -125,6 +124,21 @@ client.post_training.job.cancel(job_uuid="your-job-id")
|
|||
|
||||
### Inference with the fine-tuned model
|
||||
|
||||
#### 1. Register the model
|
||||
|
||||
```python
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
|
||||
client.models.register(
|
||||
model_id="test-example-model@v1",
|
||||
provider_id="nvidia",
|
||||
provider_model_id="test-example-model@v1",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
```
|
||||
|
||||
#### 2. Inference with the fine-tuned model
|
||||
|
||||
```python
|
||||
response = client.inference.completion(
|
||||
content="Complete the sentence using one word: Roses are red, violets are ",
|
||||
|
|
|
|||
|
|
@ -67,13 +67,18 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
|
||||
# TODO: filter by available models based on /config endpoint
|
||||
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
||||
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
|
||||
self.customizer_url = config.customizer_url
|
||||
self.session = None
|
||||
|
||||
self.customizer_url = config.customizer_url
|
||||
if not self.customizer_url:
|
||||
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
|
||||
self.customizer_url = "http://nemo.test"
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
|
||||
return self.session
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
|
|
@ -94,8 +99,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
if json and "Content-Type" not in request_headers:
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
|
||||
session = await self._get_session()
|
||||
for _ in range(self.config.max_retries):
|
||||
async with self.session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||
if response.status >= 400:
|
||||
error_data = await response.json()
|
||||
raise Exception(f"API request failed: {error_data}")
|
||||
|
|
@ -122,8 +128,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
jobs = []
|
||||
for job in response.get("data", []):
|
||||
job_id = job.pop("id")
|
||||
job_status = job.pop("status", "unknown").lower()
|
||||
mapped_status = STATUS_MAPPING.get(job_status, "unknown")
|
||||
job_status = job.pop("status", "scheduled").lower()
|
||||
mapped_status = STATUS_MAPPING.get(job_status, "scheduled")
|
||||
|
||||
# Convert string timestamps to datetime objects
|
||||
created_at = (
|
||||
|
|
@ -177,7 +183,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
)
|
||||
|
||||
api_status = response.pop("status").lower()
|
||||
mapped_status = STATUS_MAPPING.get(api_status, "unknown")
|
||||
mapped_status = STATUS_MAPPING.get(api_status, "scheduled")
|
||||
|
||||
return NvidiaPostTrainingJobStatusResponse(
|
||||
status=JobStatus(mapped_status),
|
||||
|
|
@ -239,6 +245,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
|
||||
Supported models:
|
||||
- meta/llama-3.1-8b-instruct
|
||||
- meta/llama-3.2-1b-instruct
|
||||
|
||||
Supported algorithm configs:
|
||||
- LoRA, SFT
|
||||
|
|
@ -284,10 +291,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
|
||||
- LoRA config:
|
||||
## NeMo customizer specific LoRA parameters
|
||||
- adapter_dim: int - Adapter dimension
|
||||
Default: 8 (supports powers of 2)
|
||||
- adapter_dropout: float - Adapter dropout
|
||||
Default: None (0.0-1.0)
|
||||
- alpha: int - Scaling factor for the LoRA update
|
||||
Default: 16
|
||||
Note:
|
||||
|
|
@ -297,7 +300,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
User is informed about unsupported parameters via warnings.
|
||||
"""
|
||||
# Map model to nvidia model name
|
||||
# ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models
|
||||
# See `_MODEL_ENTRIES` for supported models
|
||||
nvidia_model = self.get_provider_model_id(model)
|
||||
|
||||
# Check for unsupported method parameters
|
||||
|
|
@ -330,7 +333,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
},
|
||||
"data_config": {"dataset_id", "batch_size"},
|
||||
"optimizer_config": {"lr", "weight_decay"},
|
||||
"lora_config": {"type", "adapter_dim", "adapter_dropout", "alpha"},
|
||||
"lora_config": {"type", "alpha"},
|
||||
}
|
||||
|
||||
# Validate all parameters at once
|
||||
|
|
@ -389,16 +392,10 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
|
||||
# Handle LoRA-specific configuration
|
||||
if algorithm_config:
|
||||
if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
|
||||
if algorithm_config.type == "LoRA":
|
||||
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
|
||||
job_config["hyperparameters"]["lora"] = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"adapter_dim": algorithm_config.get("adapter_dim"),
|
||||
"alpha": algorithm_config.get("alpha"),
|
||||
"adapter_dropout": algorithm_config.get("adapter_dropout"),
|
||||
}.items()
|
||||
if v is not None
|
||||
k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue