mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-21 03:59:42 +00:00
Clean up instructions and implementation; reorganize notebooks
This commit is contained in:
parent
0d9d333a4e
commit
4131e8146f
29 changed files with 2756 additions and 89 deletions
|
@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel):
|
|||
default=60,
|
||||
description="Timeout for the HTTP requests",
|
||||
)
|
||||
append_api_version: bool = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
||||
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
|
||||
"api_key": "${env.NVIDIA_API_KEY:}",
|
||||
"append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
|
||||
}
|
||||
|
|
|
@ -42,10 +42,7 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference import (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||
)
|
||||
|
@ -126,15 +123,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
|
||||
}
|
||||
|
||||
# add /v1 in case of hosted models
|
||||
base_url = self._config.url
|
||||
if _is_nvidia_hosted(self._config):
|
||||
if provider_model_id in special_model_urls:
|
||||
base_url = special_model_urls[provider_model_id]
|
||||
else:
|
||||
base_url = f"{self._config.url}/v1"
|
||||
elif "nim.int.aire.nvidia.com" in base_url:
|
||||
base_url = f"{base_url}/v1"
|
||||
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||
|
||||
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
|
||||
base_url = special_model_urls[provider_model_id]
|
||||
return _get_client_for_base_url(base_url)
|
||||
|
||||
async def completion(
|
||||
|
@ -258,9 +250,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
# await check_health(self._config) # this raises errors
|
||||
|
||||
provider_model_id = self.get_provider_model_id(model_id)
|
||||
print(f"provider_model_id: {provider_model_id}")
|
||||
request = await convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=provider_model_id,
|
||||
model=self.get_provider_model_id(model_id),
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue