Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-23 02:03:54 +00:00
docs: update the docs for NVIDIA Inference provider (#3227)
# What does this PR do?

- Documentation update and fix for the NVIDIA Inference provider.
- Update `run_moderation` for the safety API with a `NotImplementedError` placeholder; without it, initializing the NVIDIA inference client raises an error (see the sketch below for why).

## Test Plan

N/A
This commit is contained in:

parent 1790fc0f25
commit b72169ca47

2 changed files with 76 additions and 1 deletion
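For context (an illustrative aside, not code from this commit): if the base safety interface declares `run_moderation` as abstract, an adapter that omits it cannot be instantiated at all, so even unrelated client initialization fails. A stub that raises `NotImplementedError` satisfies the interface while deferring the feature. A minimal sketch with stand-in class names:

```python
from abc import ABC, abstractmethod


class Safety(ABC):
    """Illustrative stand-in for the safety API's abstract interface."""

    @abstractmethod
    async def run_moderation(self, input: str | list[str], model: str): ...


class StubbedAdapter(Safety):
    async def run_moderation(self, input: str | list[str], model: str):
        # Satisfies the interface so the class can be instantiated;
        # the feature itself is deferred.
        raise NotImplementedError("run_moderation is not implemented")


adapter = StubbedAdapter()  # instantiates cleanly

# A subclass that omitted run_moderation entirely would instead fail at
# construction time:
#   TypeError: Can't instantiate abstract class ... run_moderation
```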
@@ -41,6 +41,11 @@ client.initialize()

### Create Completion

> Note on Completion API
>
> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with `NVIDIA_BASE_URL="https://integrate.api.nvidia.com"` do not support the `completion` method, while a locally deployed NIM does.
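As an illustrative sketch (not part of this PR's docs), a caller can pick the endpoint-appropriate method up front. The fallback logic and environment handling below are assumptions for demonstration, not documented llama-stack behavior:

```python
import os

HOSTED_BASE_URL = "https://integrate.api.nvidia.com"


def complete_text(client, prompt: str, model_id: str = "meta-llama/Llama-3.1-8B-Instruct") -> str:
    """Use `completion` against a local NIM; fall back to `chat_completion` on the hosted endpoint."""
    base_url = os.environ.get("NVIDIA_BASE_URL", HOSTED_BASE_URL)
    if base_url == HOSTED_BASE_URL:
        # Hosted NIMs reject the completion method, so wrap the prompt as a chat turn.
        response = client.inference.chat_completion(
            model_id=model_id,
            messages=[{"role": "user", "content": prompt}],
        )
        return response.completion_message.content
    response = client.inference.completion(
        model_id=model_id,
        content=prompt,
    )
    return response.content
```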
@@ -76,6 +81,73 @@ response = client.inference.chat_completion(
### Tool Calling Example

```python
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition

tool_definition = ToolDefinition(
    tool_name="get_weather",
    description="Get current weather information for a location",
    parameters={
        "location": ToolParamDefinition(
            param_type="string",
            description="The city and state, e.g. San Francisco, CA",
            required=True,
        ),
        "unit": ToolParamDefinition(
            param_type="string",
            description="Temperature unit (celsius or fahrenheit)",
            required=False,
            default="celsius",
        ),
    },
)

tool_response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
    tools=[tool_definition],
)

print(f"Tool Response: {tool_response.completion_message.content}")
if tool_response.completion_message.tool_calls:
    for tool_call in tool_response.completion_message.tool_calls:
        print(f"Tool Called: {tool_call.tool_name}")
        print(f"Arguments: {tool_call.arguments}")
```
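A natural follow-up (illustrative, not part of the PR's docs) is to dispatch the returned tool call to a local function; `get_weather` below is a stand-in implementation, and the sketch assumes `tool_call.arguments` arrives as a dict:

```python
def get_weather(location: str, unit: str = "celsius") -> str:
    # Stand-in implementation; a real tool would call a weather service.
    return f"18 degrees {unit} and foggy in {location}"


# Map tool names to local callables and execute whatever the model requested.
available_tools = {"get_weather": get_weather}

for tool_call in tool_response.completion_message.tool_calls or []:
    handler = available_tools.get(tool_call.tool_name)
    if handler is None:
        print(f"No handler for tool: {tool_call.tool_name}")
        continue
    result = handler(**tool_call.arguments)
    print(f"{tool_call.tool_name} -> {result}")
```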
### Structured Output Example

```python
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType

person_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
        "occupation": {"type": "string"},
    },
    "required": ["name", "age", "occupation"],
}

response_format = JsonSchemaResponseFormat(
    type=ResponseFormatType.json_schema, json_schema=person_schema
)

structured_response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer.",
        }
    ],
    response_format=response_format,
)

print(f"Structured Response: {structured_response.completion_message.content}")
```
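Because the response is constrained to the schema, its content can be parsed directly as JSON (an illustrative follow-up, not part of the PR):

```python
import json

profile = json.loads(structured_response.completion_message.content)
print(profile["name"], profile["age"], profile["occupation"])
```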
### Create Embeddings

> Note on OpenAI embeddings compatibility
>
```diff
@@ -9,7 +9,7 @@ from typing import Any
 import requests
 
 from llama_stack.apis.inference import Message
-from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
+from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
 from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
```
```diff
@@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
         self.shield = NeMoGuardrails(self.config, shield.shield_id)
         return await self.shield.run(messages)
 
+    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+        raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
+
 
 class NeMoGuardrails:
     """
```
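Until the provider implements moderation, callers can treat the placeholder as a soft capability signal (an illustrative pattern, not code from this commit):

```python
async def moderate_if_supported(safety, text: str, model: str):
    """Try run_moderation; degrade gracefully for providers that stub it out."""
    try:
        return await safety.run_moderation(input=text, model=model)
    except NotImplementedError:
        # The NVIDIA safety adapter currently raises this placeholder.
        return None
```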