mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 15:34:30 +00:00
Merge branch 'main' of https://github.com/meta-llama/llama-stack into add_nemo_customizer
This commit is contained in:
commit
f534b4c2ea
571 changed files with 229651 additions and 12956 deletions
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleAgentsImpl(Agents):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class HuggingfaceDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "huggingface_datasetio.db").as_posix()
|
||||
) # Uses SQLite config specific to HF storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="huggingface_datasetio.db",
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import datasets as hf_datasets
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .config import HuggingfaceDatasetIOConfig
|
||||
|
|
@ -18,22 +18,14 @@ from .config import HuggingfaceDatasetIOConfig
|
|||
DATASETS_PREFIX = "datasets:"
|
||||
|
||||
|
||||
def load_hf_dataset(dataset_def: Dataset):
|
||||
if dataset_def.metadata.get("path", None):
|
||||
dataset = hf_datasets.load_dataset(**dataset_def.metadata)
|
||||
else:
|
||||
df = get_dataframe_from_url(dataset_def.url)
|
||||
def parse_hf_params(dataset_def: Dataset):
|
||||
uri = dataset_def.source.uri
|
||||
parsed_uri = urlparse(uri)
|
||||
params = parse_qs(parsed_uri.query)
|
||||
params = {k: v[0] for k, v in params.items()}
|
||||
path = parsed_uri.path.lstrip("/")
|
||||
|
||||
if df is None:
|
||||
raise ValueError(f"Failed to load dataset from {dataset_def.url}")
|
||||
|
||||
dataset = hf_datasets.Dataset.from_pandas(df)
|
||||
|
||||
# drop columns not specified by schema
|
||||
if dataset_def.dataset_schema:
|
||||
dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
|
||||
|
||||
return dataset
|
||||
return path, params
|
||||
|
||||
|
||||
class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||
|
|
@ -64,7 +56,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=dataset_def.json(),
|
||||
value=dataset_def.model_dump_json(),
|
||||
)
|
||||
self.dataset_infos[dataset_def.identifier] = dataset_def
|
||||
|
||||
|
|
@ -73,41 +65,34 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
await self.kvstore.delete(key=key)
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def get_rows_paginated(
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
rows_in_page: int,
|
||||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> IterrowsResponse:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
loaded_dataset = load_hf_dataset(dataset_def)
|
||||
path, params = parse_hf_params(dataset_def)
|
||||
loaded_dataset = hf_datasets.load_dataset(path, **params)
|
||||
|
||||
if page_token and not page_token.isnumeric():
|
||||
raise ValueError("Invalid page_token")
|
||||
start_index = start_index or 0
|
||||
|
||||
if page_token is None or len(page_token) == 0:
|
||||
next_page_token = 0
|
||||
else:
|
||||
next_page_token = int(page_token)
|
||||
|
||||
start = next_page_token
|
||||
if rows_in_page == -1:
|
||||
if limit is None or limit == -1:
|
||||
end = len(loaded_dataset)
|
||||
else:
|
||||
end = min(start + rows_in_page, len(loaded_dataset))
|
||||
end = min(start_index + limit, len(loaded_dataset))
|
||||
|
||||
rows = [loaded_dataset[i] for i in range(start, end)]
|
||||
rows = [loaded_dataset[i] for i in range(start_index, end)]
|
||||
|
||||
return PaginatedRowsResult(
|
||||
rows=rows,
|
||||
total_count=len(rows),
|
||||
next_page_token=str(end),
|
||||
return IterrowsResponse(
|
||||
data=rows,
|
||||
next_start_index=end if end < len(loaded_dataset) else None,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
loaded_dataset = load_hf_dataset(dataset_def)
|
||||
path, params = parse_hf_params(dataset_def)
|
||||
loaded_dataset = hf_datasets.load_dataset(path, **params)
|
||||
|
||||
# Convert rows to HF Dataset format
|
||||
new_dataset = hf_datasets.Dataset.from_list(rows)
|
||||
|
|
|
|||
23
llama_stack/providers/remote/inference/anthropic/__init__.py
Normal file
23
llama_stack/providers/remote/inference/anthropic/__init__.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import AnthropicConfig
|
||||
|
||||
|
||||
class AnthropicProviderDataValidator(BaseModel):
|
||||
anthropic_api_key: Optional[str] = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: AnthropicConfig, _deps):
|
||||
from .anthropic import AnthropicInferenceAdapter
|
||||
|
||||
impl = AnthropicInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from .config import AnthropicConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: AnthropicConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="anthropic_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
32
llama_stack/providers/remote/inference/anthropic/config.py
Normal file
32
llama_stack/providers/remote/inference/anthropic/config.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class AnthropicProviderDataValidator(BaseModel):
|
||||
anthropic_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for Anthropic models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AnthropicConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for Anthropic models",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
35
llama_stack/providers/remote/inference/anthropic/models.py
Normal file
35
llama_stack/providers/remote/inference/anthropic/models.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"anthropic/claude-3-5-sonnet-latest",
|
||||
"anthropic/claude-3-7-sonnet-latest",
|
||||
"anthropic/claude-3-5-haiku-latest",
|
||||
]
|
||||
|
||||
|
||||
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3-lite",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 512, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-code-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
]
|
||||
|
|
@ -72,7 +72,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
|
@ -83,7 +83,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
@ -92,6 +92,8 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
|
|
|||
|
|
@ -46,14 +46,14 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
|
||||
from .config import CerebrasImplConfig
|
||||
from .models import model_entries
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
model_entries=model_entries,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
|
|
@ -72,11 +72,13 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
|
@ -112,7 +114,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
|
|
@ -121,6 +123,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
model_entries = [
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1-8b",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -20,3 +21,15 @@ class DatabricksImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.DATABRICKS_URL}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN}",
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
"api_token": api_token,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
|
|||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
|
@ -70,7 +71,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
|
@ -81,7 +82,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
@ -90,6 +91,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@ class FireworksImplConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": "${env.FIREWORKS_API_KEY}",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
|
@ -54,6 +55,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import FireworksImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
|
|
@ -67,8 +70,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
pass
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
if self.config.api_key is not None:
|
||||
return self.config.api_key.get_secret_value()
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
return config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.fireworks_api_key:
|
||||
|
|
@ -85,11 +89,13 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
|
@ -156,7 +162,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
|
|
@ -165,6 +171,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
|
@ -226,12 +234,15 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
if input_dict["prompt"].startswith("<|begin_of_text|>"):
|
||||
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
|
||||
|
||||
return {
|
||||
params = {
|
||||
"model": request.model,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
|
||||
}
|
||||
logger.debug(f"params to fireworks: {params}")
|
||||
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -24,10 +24,6 @@ MODEL_ENTRIES = [
|
|||
"accounts/fireworks/models/llama-v3p1-405b-instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p2-1b-instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p2-3b-instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
|
|
|
|||
23
llama_stack/providers/remote/inference/gemini/__init__.py
Normal file
23
llama_stack/providers/remote/inference/gemini/__init__.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import GeminiConfig
|
||||
|
||||
|
||||
class GeminiProviderDataValidator(BaseModel):
|
||||
gemini_api_key: Optional[str] = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GeminiConfig, _deps):
|
||||
from .gemini import GeminiInferenceAdapter
|
||||
|
||||
impl = GeminiInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
32
llama_stack/providers/remote/inference/gemini/config.py
Normal file
32
llama_stack/providers/remote/inference/gemini/config.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GeminiProviderDataValidator(BaseModel):
|
||||
gemini_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for Gemini models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GeminiConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for Gemini models",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
27
llama_stack/providers/remote/inference/gemini/gemini.py
Normal file
27
llama_stack/providers/remote/inference/gemini/gemini.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from .config import GeminiConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: GeminiConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="gemini_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
24
llama_stack/providers/remote/inference/gemini/models.py
Normal file
24
llama_stack/providers/remote/inference/gemini/models.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"gemini/gemini-1.5-flash",
|
||||
"gemini/gemini-1.5-pro",
|
||||
]
|
||||
|
||||
|
||||
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="gemini/text-embedding-004",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768, "context_length": 2048},
|
||||
),
|
||||
]
|
||||
|
|
@ -4,23 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import GroqConfig
|
||||
|
||||
|
||||
class GroqProviderDataValidator(BaseModel):
|
||||
groq_api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .groq import GroqInferenceAdapter
|
||||
|
||||
if not isinstance(config, GroqConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
|
||||
adapter = GroqInferenceAdapter(config)
|
||||
return adapter
|
||||
|
|
|
|||
|
|
@ -4,13 +4,20 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GroqProviderDataValidator(BaseModel):
|
||||
groq_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for Groq models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GroqConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
|
|
@ -18,3 +25,15 @@ class GroqConfig(BaseModel):
|
|||
default=None,
|
||||
description="The Groq API key",
|
||||
)
|
||||
|
||||
url: str = Field(
|
||||
default="https://api.groq.com",
|
||||
description="The URL for the Groq AI server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.groq.com",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,163 +4,26 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from typing import AsyncIterator, List, Optional, Union
|
||||
|
||||
import groq
|
||||
from groq import Groq
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
SamplingParams,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.sku_list import CoreModelId
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
build_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from .groq_utils import (
|
||||
convert_chat_completion_request,
|
||||
convert_chat_completion_response,
|
||||
convert_chat_completion_response_stream,
|
||||
)
|
||||
|
||||
_MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"llama3-8b-8192",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama-3.1-8b-instant",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3-70b-8192",
|
||||
CoreModelId.llama3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-3.3-70b-versatile",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
# Groq only contains a preview version for llama-3.2-3b
|
||||
# Preview models aren't recommended for production use, but we include this one
|
||||
# to pass the test fixture
|
||||
# TODO(aidand): Replace this with a stable model once Groq supports it
|
||||
build_hf_repo_model_entry(
|
||||
"llama-3.2-3b-preview",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
]
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
|
||||
class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: GroqConfig
|
||||
|
||||
def __init__(self, config: GroqConfig):
|
||||
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
||||
self._config = config
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
# Groq doesn't support non-chat completion as of time of writing
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
model_id = self.get_provider_model_id(model_id)
|
||||
if model_id == "llama-3.2-3b-preview":
|
||||
warnings.warn(
|
||||
"Groq only contains a preview version for llama-3.2-3b-instruct. "
|
||||
"Preview models aren't recommended for production use. "
|
||||
"They can be discontinued on short notice."
|
||||
"More details: https://console.groq.com/docs/models"
|
||||
)
|
||||
|
||||
request = convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="groq_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
try:
|
||||
response = self._get_client().chat.completions.create(**request)
|
||||
except groq.BadRequestError as e:
|
||||
if e.body.get("error", {}).get("code") == "tool_use_failed":
|
||||
# For smaller models, Groq may fail to call a tool even when the request is well formed
|
||||
raise ValueError("Groq failed to call a tool", e.body.get("error", {})) from e
|
||||
else:
|
||||
raise e
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
if stream:
|
||||
return convert_chat_completion_response_stream(response)
|
||||
else:
|
||||
return convert_chat_completion_response(response)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_client(self) -> Groq:
|
||||
if self._config.api_key is not None:
|
||||
return Groq(api_key=self._config.api_key)
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.groq_api_key:
|
||||
raise ValueError(
|
||||
'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "<your api key>" }'
|
||||
)
|
||||
return Groq(api_key=provider_data.groq_api_key)
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
|
|
|||
|
|
@ -1,245 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from typing import AsyncGenerator, Literal
|
||||
|
||||
from groq import Stream
|
||||
from groq.types.chat.chat_completion import ChatCompletion
|
||||
from groq.types.chat.chat_completion_assistant_message_param import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
)
|
||||
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
|
||||
from groq.types.chat.chat_completion_system_message_param import (
|
||||
ChatCompletionSystemMessageParam,
|
||||
)
|
||||
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
|
||||
from groq.types.chat.chat_completion_user_message_param import (
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from groq.types.chat.completion_create_params import CompletionCreateParams
|
||||
from groq.types.shared.function_definition import FunctionDefinition
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
Message,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
UnparseableToolCall,
|
||||
convert_tool_call,
|
||||
get_sampling_strategy_options,
|
||||
)
|
||||
|
||||
|
||||
def convert_chat_completion_request(
|
||||
request: ChatCompletionRequest,
|
||||
) -> CompletionCreateParams:
|
||||
"""
|
||||
Convert a ChatCompletionRequest to a Groq API-compatible dictionary.
|
||||
Warns client if request contains unsupported features.
|
||||
"""
|
||||
|
||||
if request.logprobs:
|
||||
# Groq doesn't support logprobs at the time of writing
|
||||
warnings.warn("logprobs are not supported yet")
|
||||
|
||||
if request.response_format:
|
||||
# Groq's JSON mode is beta at the time of writing
|
||||
warnings.warn("response_format is not supported yet")
|
||||
|
||||
if request.sampling_params.repetition_penalty != 1.0:
|
||||
# groq supports frequency_penalty, but frequency_penalty and sampling_params.repetition_penalty
|
||||
# seem to have different semantics
|
||||
# frequency_penalty defaults to 0 is a float between -2.0 and 2.0
|
||||
# repetition_penalty defaults to 1 and is often set somewhere between 1.0 and 2.0
|
||||
# so we exclude it for now
|
||||
warnings.warn("repetition_penalty is not supported")
|
||||
|
||||
if request.tool_config.tool_prompt_format != ToolPromptFormat.json:
|
||||
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
|
||||
|
||||
sampling_options = get_sampling_strategy_options(request.sampling_params)
|
||||
return CompletionCreateParams(
|
||||
model=request.model,
|
||||
messages=[_convert_message(message) for message in request.messages],
|
||||
logprobs=None,
|
||||
frequency_penalty=None,
|
||||
stream=request.stream,
|
||||
max_tokens=request.sampling_params.max_tokens or None,
|
||||
temperature=sampling_options.get("temperature", 1.0),
|
||||
top_p=sampling_options.get("top_p", 1.0),
|
||||
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
|
||||
tool_choice=(request.tool_config.tool_choice.value if request.tool_config.tool_choice else None),
|
||||
)
|
||||
|
||||
|
||||
def _convert_message(message: Message) -> ChatCompletionMessageParam:
|
||||
if message.role == "system":
|
||||
return ChatCompletionSystemMessageParam(role="system", content=message.content)
|
||||
elif message.role == "user":
|
||||
return ChatCompletionUserMessageParam(role="user", content=message.content)
|
||||
elif message.role == "assistant":
|
||||
return ChatCompletionAssistantMessageParam(role="assistant", content=message.content)
|
||||
else:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
|
||||
|
||||
def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
|
||||
# Groq requires a description for function tools
|
||||
if tool_definition.description is None:
|
||||
raise AssertionError("tool_definition.description is required")
|
||||
|
||||
tool_parameters = tool_definition.parameters or {}
|
||||
return ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=FunctionDefinition(
|
||||
name=tool_definition.tool_name,
|
||||
description=tool_definition.description,
|
||||
parameters={key: _convert_groq_tool_parameter(param) for key, param in tool_parameters.items()},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict:
|
||||
param = {
|
||||
"type": tool_parameter.param_type,
|
||||
}
|
||||
if tool_parameter.description is not None:
|
||||
param["description"] = tool_parameter.description
|
||||
if tool_parameter.required is not None:
|
||||
param["required"] = tool_parameter.required
|
||||
if tool_parameter.default is not None:
|
||||
param["default"] = tool_parameter.default
|
||||
return param
|
||||
|
||||
|
||||
def convert_chat_completion_response(
|
||||
response: ChatCompletion,
|
||||
) -> ChatCompletionResponse:
|
||||
# groq only supports n=1 at time of writing, so there is only one choice
|
||||
choice = response.choices[0]
|
||||
if choice.finish_reason == "tool_calls":
|
||||
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
|
||||
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
|
||||
# If we couldn't parse a tool call, jsonify the tool calls and return them
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
stop_reason=StopReason.end_of_message,
|
||||
content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
|
||||
),
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
# Otherwise, return tool calls as normal
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
tool_calls=tool_calls,
|
||||
stop_reason=StopReason.end_of_message,
|
||||
# Content is not optional
|
||||
content="",
|
||||
),
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=choice.message.content,
|
||||
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _map_finish_reason_to_stop_reason(
|
||||
finish_reason: Literal["stop", "length", "tool_calls"],
|
||||
) -> StopReason:
|
||||
"""
|
||||
Convert a Groq chat completion finish_reason to a StopReason.
|
||||
|
||||
finish_reason: Literal["stop", "length", "tool_calls"]
|
||||
- stop -> model hit a natural stop point or a provided stop sequence
|
||||
- length -> maximum number of tokens specified in the request was reached
|
||||
- tool_calls -> model called a tool
|
||||
"""
|
||||
if finish_reason == "stop":
|
||||
return StopReason.end_of_turn
|
||||
elif finish_reason == "length":
|
||||
return StopReason.out_of_tokens
|
||||
elif finish_reason == "tool_calls":
|
||||
return StopReason.end_of_message
|
||||
else:
|
||||
raise ValueError(f"Invalid finish reason: {finish_reason}")
|
||||
|
||||
|
||||
async def convert_chat_completion_response_stream(
|
||||
stream: Stream[ChatCompletionChunk],
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
event_type = ChatCompletionResponseEventType.start
|
||||
for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
|
||||
if choice.finish_reason:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=choice.delta.content or ""),
|
||||
logprobs=None,
|
||||
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
|
||||
)
|
||||
)
|
||||
elif choice.delta.tool_calls:
|
||||
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
|
||||
if len(choice.delta.tool_calls) > 1:
|
||||
warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
|
||||
|
||||
# We assume Groq produces fully formed tool calls for each chunk
|
||||
tool_call = convert_tool_call(choice.delta.tool_calls[0])
|
||||
if isinstance(tool_call, ToolCall):
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=tool_call,
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Otherwise it's an UnparseableToolCall - return the raw tool call
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=tool_call.model_dump_json(),
|
||||
parse_status=ToolCallParseStatus.failed,
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=TextDelta(text=choice.delta.content or ""),
|
||||
logprobs=None,
|
||||
)
|
||||
)
|
||||
event_type = ChatCompletionResponseEventType.progress
|
||||
38
llama_stack/providers/remote/inference/groq/models.py
Normal file
38
llama_stack/providers/remote/inference/groq/models.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_list import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_hf_repo_model_entry,
|
||||
build_model_entry,
|
||||
)
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama3-8b-8192",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"groq/llama-3.1-8b-instant",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama3-70b-8192",
|
||||
CoreModelId.llama3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-3.3-70b-versatile",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
# Groq only contains a preview version for llama-3.2-3b
|
||||
# Preview models aren't recommended for production use, but we include this one
|
||||
# to pass the test fixture
|
||||
# TODO(aidand): Replace this with a stable model once Groq supports it
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-3.2-3b-preview",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
]
|
||||
|
|
@ -11,7 +11,7 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
_MODEL_ENTRIES = [
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama3-8b-instruct",
|
||||
CoreModelId.llama3_8b_instruct.value,
|
||||
|
|
@ -48,12 +48,49 @@ _MODEL_ENTRIES = [
|
|||
"meta/llama-3.2-90b-vision-instruct",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
# NeMo Retriever Text Embedding models -
|
||||
#
|
||||
# https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
|
||||
#
|
||||
# +-----------------------------------+--------+-----------+-----------+------------+
|
||||
# | Model ID | Max | Publisher | Embedding | Dynamic |
|
||||
# | | Tokens | | Dimension | Embeddings |
|
||||
# +-----------------------------------+--------+-----------+-----------+------------+
|
||||
# | nvidia/llama-3.2-nv-embedqa-1b-v2 | 8192 | NVIDIA | 2048 | Yes |
|
||||
# | nvidia/nv-embedqa-e5-v5 | 512 | NVIDIA | 1024 | No |
|
||||
# | nvidia/nv-embedqa-mistral-7b-v2 | 512 | NVIDIA | 4096 | No |
|
||||
# | snowflake/arctic-embed-l | 512 | Snowflake | 1024 | No |
|
||||
# +-----------------------------------+--------+-----------+-----------+------------+
|
||||
ProviderModelEntry(
|
||||
provider_model_id="baai/bge-m3",
|
||||
provider_model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 2048,
|
||||
"context_length": 8192,
|
||||
},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="nvidia/nv-embedqa-e5-v5",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 1024,
|
||||
"context_length": 8192,
|
||||
"context_length": 512,
|
||||
},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="nvidia/nv-embedqa-mistral-7b-v2",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 4096,
|
||||
"context_length": 512,
|
||||
},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="snowflake/arctic-embed-l",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 1024,
|
||||
"context_length": 512,
|
||||
},
|
||||
),
|
||||
# TODO(mf): how do we handle Nemotron models?
|
||||
|
|
|
|||
|
|
@ -6,9 +6,10 @@
|
|||
|
||||
import logging
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
from typing import AsyncIterator, List, Optional, Union
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI
|
||||
from openai import APIConnectionError, AsyncOpenAI, BadRequestError
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
|
@ -40,15 +41,17 @@ from llama_stack.models.llama.datatypes import (
|
|||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_openai_chat_completion_choice,
|
||||
convert_openai_chat_completion_stream,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
|
||||
|
||||
from . import NVIDIAConfig
|
||||
from .models import _MODEL_ENTRIES
|
||||
from .models import MODEL_ENTRIES
|
||||
from .openai_utils import (
|
||||
convert_chat_completion_request,
|
||||
convert_completion_request,
|
||||
convert_openai_chat_completion_choice,
|
||||
convert_openai_chat_completion_stream,
|
||||
convert_openai_completion_choice,
|
||||
convert_openai_completion_stream,
|
||||
)
|
||||
|
|
@ -60,7 +63,7 @@ logger = logging.getLogger(__name__)
|
|||
class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||
def __init__(self, config: NVIDIAConfig) -> None:
|
||||
# TODO(mf): filter by available models
|
||||
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
||||
|
||||
|
|
@ -80,22 +83,54 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
# )
|
||||
|
||||
self._config = config
|
||||
# make sure the client lives longer than any async calls
|
||||
self._client = AsyncOpenAI(
|
||||
base_url=f"{self._config.url}/v1",
|
||||
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
|
||||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
@lru_cache # noqa: B019
|
||||
def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
|
||||
"""
|
||||
For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
|
||||
some models are hosted on different URLs. This function returns the appropriate client
|
||||
for the given provider_model_id.
|
||||
|
||||
This relies on lru_cache and self._default_client to avoid creating a new client for each request
|
||||
or for each model that is hosted on https://integrate.api.nvidia.com/v1.
|
||||
|
||||
:param provider_model_id: The provider model ID
|
||||
:return: An OpenAI client
|
||||
"""
|
||||
|
||||
@lru_cache # noqa: B019
|
||||
def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
|
||||
"""
|
||||
Maintain a single OpenAI client per base_url.
|
||||
"""
|
||||
return AsyncOpenAI(
|
||||
base_url=base_url,
|
||||
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
|
||||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
special_model_urls = {
|
||||
"meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
|
||||
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
|
||||
}
|
||||
|
||||
base_url = f"{self._config.url}/v1"
|
||||
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
|
||||
base_url = special_model_urls[provider_model_id]
|
||||
|
||||
return _get_client_for_base_url(base_url)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if content_has_media(content):
|
||||
raise NotImplementedError("Media is not supported")
|
||||
|
||||
|
|
@ -103,9 +138,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
# removing this health check as NeMo customizer endpoint health check is returning 404
|
||||
# await check_health(self._config) # this raises errors
|
||||
|
||||
provider_model_id = self.get_provider_model_id(model_id)
|
||||
request = convert_completion_request(
|
||||
request=CompletionRequest(
|
||||
model=self.get_provider_model_id(model_id),
|
||||
model=provider_model_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
|
|
@ -116,7 +152,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
response = await self._client.completions.create(**request)
|
||||
response = await self._get_client(provider_model_id).completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
|
|
@ -144,19 +180,38 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
#
|
||||
# we can ignore str and always pass List[str] to OpenAI
|
||||
#
|
||||
flat_contents = [
|
||||
item.text if isinstance(item, TextContentItem) else item
|
||||
for content in contents
|
||||
for item in (content if isinstance(content, list) else [content])
|
||||
]
|
||||
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
|
||||
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
|
||||
model = self.get_provider_model_id(model_id)
|
||||
|
||||
response = await self._client.embeddings.create(
|
||||
model=model,
|
||||
input=input,
|
||||
# extra_body={"input_type": "passage"|"query"}, # TODO(mf): how to tell caller's intent?
|
||||
)
|
||||
extra_body = {}
|
||||
|
||||
if text_truncation is not None:
|
||||
text_truncation_options = {
|
||||
TextTruncation.none: "NONE",
|
||||
TextTruncation.end: "END",
|
||||
TextTruncation.start: "START",
|
||||
}
|
||||
extra_body["truncate"] = text_truncation_options[text_truncation]
|
||||
|
||||
if output_dimension is not None:
|
||||
extra_body["dimensions"] = output_dimension
|
||||
|
||||
if task_type is not None:
|
||||
task_type_options = {
|
||||
EmbeddingTaskType.document: "passage",
|
||||
EmbeddingTaskType.query: "query",
|
||||
}
|
||||
extra_body["input_type"] = task_type_options[task_type]
|
||||
|
||||
try:
|
||||
response = await self._client.embeddings.create(
|
||||
model=model,
|
||||
input=input,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
except BadRequestError as e:
|
||||
raise ValueError(f"Failed to get embeddings: {e}") from e
|
||||
|
||||
#
|
||||
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)
|
||||
|
|
@ -169,7 +224,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
@ -178,11 +233,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if tool_prompt_format:
|
||||
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
|
||||
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
|
||||
|
||||
# await check_health(self._config) # this raises errors
|
||||
|
||||
provider_model_id = self.get_provider_model_id(model_id)
|
||||
request = await convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=self.get_provider_model_id(model_id),
|
||||
|
|
@ -198,12 +256,12 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
response = await self._client.chat.completions.create(**request)
|
||||
response = await self._get_client(provider_model_id).chat.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
if stream:
|
||||
return convert_openai_chat_completion_stream(response)
|
||||
return convert_openai_chat_completion_stream(response, enable_incremental_tool_calls=False)
|
||||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return convert_openai_chat_completion_choice(response.choices[0])
|
||||
|
|
|
|||
|
|
@ -4,249 +4,36 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
|
||||
)
|
||||
from openai.types.chat.chat_completion import (
|
||||
Choice as OpenAIChoice,
|
||||
)
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
|
||||
)
|
||||
from openai.types.chat.chat_completion_content_part_image_param import (
|
||||
ImageURL as OpenAIImageURL,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
Function as OpenAIFunction,
|
||||
)
|
||||
from openai.types.completion import Completion as OpenAICompletion
|
||||
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
TextContentItem,
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
JsonSchemaResponseFormat,
|
||||
Message,
|
||||
SystemMessage,
|
||||
TokenLogProbs,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
GreedySamplingStrategy,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_content_to_url,
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
_convert_openai_finish_reason,
|
||||
convert_message_to_openai_dict_new,
|
||||
convert_tooldef_to_openai_tool,
|
||||
)
|
||||
|
||||
|
||||
def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
||||
"""
|
||||
Convert a ToolDefinition to an OpenAI API-compatible dictionary.
|
||||
|
||||
ToolDefinition:
|
||||
tool_name: str | BuiltinTool
|
||||
description: Optional[str]
|
||||
parameters: Optional[Dict[str, ToolParamDefinition]]
|
||||
|
||||
ToolParamDefinition:
|
||||
param_type: str
|
||||
description: Optional[str]
|
||||
required: Optional[bool]
|
||||
default: Optional[Any]
|
||||
|
||||
|
||||
OpenAI spec -
|
||||
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"description": description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
param_name: {
|
||||
"type": param_type,
|
||||
"description": description,
|
||||
"default": default,
|
||||
},
|
||||
...
|
||||
},
|
||||
"required": [param_name, ...],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""
|
||||
out = {
|
||||
"type": "function",
|
||||
"function": {},
|
||||
}
|
||||
function = out["function"]
|
||||
|
||||
if isinstance(tool.tool_name, BuiltinTool):
|
||||
function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient?
|
||||
else:
|
||||
function.update(name=tool.tool_name)
|
||||
|
||||
if tool.description:
|
||||
function.update(description=tool.description)
|
||||
|
||||
if tool.parameters:
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
}
|
||||
properties = parameters["properties"]
|
||||
required = []
|
||||
for param_name, param in tool.parameters.items():
|
||||
properties[param_name] = {"type": param.param_type}
|
||||
if param.description:
|
||||
properties[param_name].update(description=param.description)
|
||||
if param.default:
|
||||
properties[param_name].update(default=param.default)
|
||||
if param.required:
|
||||
required.append(param_name)
|
||||
|
||||
if required:
|
||||
parameters.update(required=required)
|
||||
|
||||
function.update(parameters=parameters)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
async def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
|
||||
"""
|
||||
Convert a Message to an OpenAI API-compatible dictionary.
|
||||
"""
|
||||
# users can supply a dict instead of a Message object, we'll
|
||||
# convert it to a Message object and proceed with some type safety.
|
||||
if isinstance(message, dict):
|
||||
if "role" not in message:
|
||||
raise ValueError("role is required in message")
|
||||
if message["role"] == "user":
|
||||
message = UserMessage(**message)
|
||||
elif message["role"] == "assistant":
|
||||
message = CompletionMessage(**message)
|
||||
elif message["role"] == "tool":
|
||||
message = ToolResponseMessage(**message)
|
||||
elif message["role"] == "system":
|
||||
message = SystemMessage(**message)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message role: {message['role']}")
|
||||
|
||||
# Map Llama Stack spec to OpenAI spec -
|
||||
# str -> str
|
||||
# {"type": "text", "text": ...} -> {"type": "text", "text": ...}
|
||||
# {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
|
||||
# {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
|
||||
# List[...] -> List[...]
|
||||
async def _convert_user_message_content(
|
||||
content: InterleavedContent,
|
||||
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
|
||||
# Llama Stack and OpenAI spec match for str and text input
|
||||
if isinstance(content, str) or isinstance(content, TextContentItem):
|
||||
return content
|
||||
elif isinstance(content, ImageContentItem):
|
||||
return OpenAIChatCompletionContentPartImageParam(
|
||||
image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)),
|
||||
type="image_url",
|
||||
)
|
||||
elif isinstance(content, List):
|
||||
return [await _convert_user_message_content(item) for item in content]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(content)}")
|
||||
|
||||
out: OpenAIChatCompletionMessage = None
|
||||
if isinstance(message, UserMessage):
|
||||
out = OpenAIChatCompletionUserMessage(
|
||||
role="user",
|
||||
content=await _convert_user_message_content(message.content),
|
||||
)
|
||||
elif isinstance(message, CompletionMessage):
|
||||
out = OpenAIChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=message.content,
|
||||
tool_calls=[
|
||||
OpenAIChatCompletionMessageToolCall(
|
||||
id=tool.call_id,
|
||||
function=OpenAIFunction(
|
||||
name=tool.tool_name,
|
||||
arguments=json.dumps(tool.arguments),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
for tool in message.tool_calls
|
||||
],
|
||||
)
|
||||
elif isinstance(message, ToolResponseMessage):
|
||||
out = OpenAIChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=message.call_id,
|
||||
content=message.content,
|
||||
)
|
||||
elif isinstance(message, SystemMessage):
|
||||
out = OpenAIChatCompletionSystemMessage(
|
||||
role="system",
|
||||
content=message.content,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message type: {type(message)}")
|
||||
|
||||
return out
|
||||
|
||||
|
||||
async def convert_chat_completion_request(
|
||||
request: ChatCompletionRequest,
|
||||
n: int = 1,
|
||||
|
|
@ -281,7 +68,7 @@ async def convert_chat_completion_request(
|
|||
nvext = {}
|
||||
payload: Dict[str, Any] = dict(
|
||||
model=request.model,
|
||||
messages=[await _convert_message(message) for message in request.messages],
|
||||
messages=[await convert_message_to_openai_dict_new(message) for message in request.messages],
|
||||
stream=request.stream,
|
||||
n=n,
|
||||
extra_body=dict(nvext=nvext),
|
||||
|
|
@ -296,7 +83,7 @@ async def convert_chat_completion_request(
|
|||
nvext.update(guided_json=request.response_format.json_schema)
|
||||
|
||||
if request.tools:
|
||||
payload.update(tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools])
|
||||
payload.update(tools=[convert_tooldef_to_openai_tool(tool) for tool in request.tools])
|
||||
if request.tool_config.tool_choice:
|
||||
payload.update(
|
||||
tool_choice=request.tool_config.tool_choice.value
|
||||
|
|
@ -319,7 +106,7 @@ async def convert_chat_completion_request(
|
|||
payload.update(temperature=strategy.temperature)
|
||||
elif isinstance(strategy, TopKSamplingStrategy):
|
||||
if strategy.top_k != -1 and strategy.top_k < 1:
|
||||
warnings.warn("top_k must be -1 or >= 1")
|
||||
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
|
||||
nvext.update(top_k=strategy.top_k)
|
||||
elif isinstance(strategy, GreedySamplingStrategy):
|
||||
nvext.update(top_k=-1)
|
||||
|
|
@ -329,239 +116,6 @@ async def convert_chat_completion_request(
|
|||
return payload
|
||||
|
||||
|
||||
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
|
||||
"""
|
||||
Convert an OpenAI chat completion finish_reason to a StopReason.
|
||||
|
||||
finish_reason: Literal["stop", "length", "tool_calls", ...]
|
||||
- stop: model hit a natural stop point or a provided stop sequence
|
||||
- length: maximum number of tokens specified in the request was reached
|
||||
- tool_calls: model called a tool
|
||||
|
||||
->
|
||||
|
||||
class StopReason(Enum):
|
||||
end_of_turn = "end_of_turn"
|
||||
end_of_message = "end_of_message"
|
||||
out_of_tokens = "out_of_tokens"
|
||||
"""
|
||||
|
||||
# TODO(mf): are end_of_turn and end_of_message semantics correct?
|
||||
return {
|
||||
"stop": StopReason.end_of_turn,
|
||||
"length": StopReason.out_of_tokens,
|
||||
"tool_calls": StopReason.end_of_message,
|
||||
}.get(finish_reason, StopReason.end_of_turn)
|
||||
|
||||
|
||||
def _convert_openai_tool_calls(
|
||||
tool_calls: List[OpenAIChatCompletionMessageToolCall],
|
||||
) -> List[ToolCall]:
|
||||
"""
|
||||
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
|
||||
|
||||
OpenAI ChatCompletionMessageToolCall:
|
||||
id: str
|
||||
function: Function
|
||||
type: Literal["function"]
|
||||
|
||||
OpenAI Function:
|
||||
arguments: str
|
||||
name: str
|
||||
|
||||
->
|
||||
|
||||
ToolCall:
|
||||
call_id: str
|
||||
tool_name: str
|
||||
arguments: Dict[str, ...]
|
||||
"""
|
||||
if not tool_calls:
|
||||
return [] # CompletionMessage tool_calls is not optional
|
||||
|
||||
return [
|
||||
ToolCall(
|
||||
call_id=call.id,
|
||||
tool_name=call.function.name,
|
||||
arguments=json.loads(call.function.arguments),
|
||||
)
|
||||
for call in tool_calls
|
||||
]
|
||||
|
||||
|
||||
def _convert_openai_logprobs(
|
||||
logprobs: OpenAIChoiceLogprobs,
|
||||
) -> Optional[List[TokenLogProbs]]:
|
||||
"""
|
||||
Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.
|
||||
|
||||
OpenAI ChoiceLogprobs:
|
||||
content: Optional[List[ChatCompletionTokenLogprob]]
|
||||
|
||||
OpenAI ChatCompletionTokenLogprob:
|
||||
token: str
|
||||
logprob: float
|
||||
top_logprobs: List[TopLogprob]
|
||||
|
||||
OpenAI TopLogprob:
|
||||
token: str
|
||||
logprob: float
|
||||
|
||||
->
|
||||
|
||||
TokenLogProbs:
|
||||
logprobs_by_token: Dict[str, float]
|
||||
- token, logprob
|
||||
|
||||
"""
|
||||
if not logprobs:
|
||||
return None
|
||||
|
||||
return [
|
||||
TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs})
|
||||
for content in logprobs.content
|
||||
]
|
||||
|
||||
|
||||
def convert_openai_chat_completion_choice(
|
||||
choice: OpenAIChoice,
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Convert an OpenAI Choice into a ChatCompletionResponse.
|
||||
|
||||
OpenAI Choice:
|
||||
message: ChatCompletionMessage
|
||||
finish_reason: str
|
||||
logprobs: Optional[ChoiceLogprobs]
|
||||
|
||||
OpenAI ChatCompletionMessage:
|
||||
role: Literal["assistant"]
|
||||
content: Optional[str]
|
||||
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
|
||||
|
||||
->
|
||||
|
||||
ChatCompletionResponse:
|
||||
completion_message: CompletionMessage
|
||||
logprobs: Optional[List[TokenLogProbs]]
|
||||
|
||||
CompletionMessage:
|
||||
role: Literal["assistant"]
|
||||
content: str | ImageMedia | List[str | ImageMedia]
|
||||
stop_reason: StopReason
|
||||
tool_calls: List[ToolCall]
|
||||
|
||||
class StopReason(Enum):
|
||||
end_of_turn = "end_of_turn"
|
||||
end_of_message = "end_of_message"
|
||||
out_of_tokens = "out_of_tokens"
|
||||
"""
|
||||
assert hasattr(choice, "message") and choice.message, "error in server response: message not found"
|
||||
assert hasattr(choice, "finish_reason") and choice.finish_reason, (
|
||||
"error in server response: finish_reason not found"
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=choice.message.content or "", # CompletionMessage content is not optional
|
||||
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
|
||||
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
|
||||
),
|
||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||
)
|
||||
|
||||
|
||||
async def convert_openai_chat_completion_stream(
|
||||
stream: AsyncStream[OpenAIChatCompletionChunk],
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
"""
|
||||
Convert a stream of OpenAI chat completion chunks into a stream
|
||||
of ChatCompletionResponseStreamChunk.
|
||||
"""
|
||||
|
||||
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
|
||||
def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]:
|
||||
yield ChatCompletionResponseEventType.start
|
||||
while True:
|
||||
yield ChatCompletionResponseEventType.progress
|
||||
|
||||
event_type = _event_type_generator()
|
||||
|
||||
# we implement NIM specific semantics, the main difference from OpenAI
|
||||
# is that tool_calls are always produced as a complete call. there is no
|
||||
# intermediate / partial tool call streamed. because of this, we can
|
||||
# simplify the logic and not concern outselves with parse_status of
|
||||
# started/in_progress/failed. we can always assume success.
|
||||
#
|
||||
# a stream of ChatCompletionResponseStreamChunk consists of
|
||||
# 0. a start event
|
||||
# 1. zero or more progress events
|
||||
# - each progress event has a delta
|
||||
# - each progress event may have a stop_reason
|
||||
# - each progress event may have logprobs
|
||||
# - each progress event may have tool_calls
|
||||
# if a progress event has tool_calls,
|
||||
# it is fully formed and
|
||||
# can be emitted with a parse_status of success
|
||||
# 2. a complete event
|
||||
|
||||
stop_reason = None
|
||||
|
||||
async for chunk in stream:
|
||||
choice = chunk.choices[0] # assuming only one choice per chunk
|
||||
|
||||
# we assume there's only one finish_reason in the stream
|
||||
stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason
|
||||
|
||||
# if there's a tool call, emit an event for each tool in the list
|
||||
# if tool call and content, emit both separately
|
||||
|
||||
if choice.delta.tool_calls:
|
||||
# the call may have content and a tool call. ChatCompletionResponseEvent
|
||||
# does not support both, so we emit the content first
|
||||
if choice.delta.content:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=next(event_type),
|
||||
delta=TextDelta(text=choice.delta.content),
|
||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||
)
|
||||
)
|
||||
|
||||
# it is possible to have parallel tool calls in stream, but
|
||||
# ChatCompletionResponseEvent only supports one per stream
|
||||
if len(choice.delta.tool_calls) > 1:
|
||||
warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest")
|
||||
|
||||
# NIM only produces fully formed tool calls, so we can assume success
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=next(event_type),
|
||||
delta=ToolCallDelta(
|
||||
tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=next(event_type),
|
||||
delta=TextDelta(text=choice.delta.content or ""),
|
||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def convert_completion_request(
|
||||
request: CompletionRequest,
|
||||
n: int = 1,
|
||||
|
|
@ -614,7 +168,7 @@ def convert_completion_request(
|
|||
payload.update(top_p=request.sampling_params.top_p)
|
||||
elif request.sampling_params.strategy == "top_k":
|
||||
if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
|
||||
warnings.warn("top_k must be -1 or >= 1")
|
||||
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
|
||||
nvext.update(top_k=request.sampling_params.top_k)
|
||||
elif request.sampling_params.strategy == "greedy":
|
||||
nvext.update(top_k=-1)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
|
@ -34,6 +34,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
|
|
@ -58,7 +59,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .models import model_entries
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
|
|
@ -71,7 +72,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
return AsyncClient(host=self.url)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
try:
|
||||
await self.client.ps()
|
||||
except httpx.ConnectError as e:
|
||||
|
|
@ -89,11 +90,13 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
|
@ -144,7 +147,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
@ -153,6 +156,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
|
@ -203,12 +208,15 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
else:
|
||||
raise ValueError(f"Unknown response format type: {fmt.type}")
|
||||
|
||||
return {
|
||||
params = {
|
||||
"model": request.model,
|
||||
**input_dict,
|
||||
"options": sampling_options,
|
||||
"stream": request.stream,
|
||||
}
|
||||
logger.debug(f"params to ollama: {params}")
|
||||
|
||||
return params
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
|
|
@ -283,7 +291,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.register_helper.register_model(model)
|
||||
if model.model_type == ModelType.embedding:
|
||||
log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
|
||||
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
|
||||
await self.client.pull(model.provider_resource_id)
|
||||
response = await self.client.list()
|
||||
else:
|
||||
|
|
|
|||
23
llama_stack/providers/remote/inference/openai/__init__.py
Normal file
23
llama_stack/providers/remote/inference/openai/__init__.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import OpenAIConfig
|
||||
|
||||
|
||||
class OpenAIProviderDataValidator(BaseModel):
|
||||
openai_api_key: Optional[str] = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: OpenAIConfig, _deps):
|
||||
from .openai import OpenAIInferenceAdapter
|
||||
|
||||
impl = OpenAIInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
32
llama_stack/providers/remote/inference/openai/config.py
Normal file
32
llama_stack/providers/remote/inference/openai/config.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class OpenAIProviderDataValidator(BaseModel):
|
||||
openai_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for OpenAI models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for OpenAI models",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
30
llama_stack/providers/remote/inference/openai/models.py
Normal file
30
llama_stack/providers/remote/inference/openai/models.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4o-mini",
|
||||
"openai/chatgpt-4o-latest",
|
||||
]
|
||||
|
||||
|
||||
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="openai/text-embedding-3-small",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1536, "context_length": 8192},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="openai/text-embedding-3-large",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 3072, "context_length": 8192},
|
||||
),
|
||||
]
|
||||
27
llama_stack/providers/remote/inference/openai/openai.py
Normal file
27
llama_stack/providers/remote/inference/openai/openai.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from .config import OpenAIConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: OpenAIConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="openai_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
|
|
@ -4,12 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from llama_stack_client import LlamaStackClient
|
||||
from llama_stack_client import AsyncLlamaStackClient
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
|
|
@ -24,6 +27,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from .config import PassthroughImplConfig
|
||||
|
|
@ -46,7 +50,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
async def register_model(self, model: Model) -> Model:
|
||||
return model
|
||||
|
||||
def _get_client(self) -> LlamaStackClient:
|
||||
def _get_client(self) -> AsyncLlamaStackClient:
|
||||
passthrough_url = None
|
||||
passthrough_api_key = None
|
||||
provider_data = None
|
||||
|
|
@ -71,7 +75,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
)
|
||||
passthrough_api_key = provider_data.passthrough_api_key
|
||||
|
||||
return LlamaStackClient(
|
||||
return AsyncLlamaStackClient(
|
||||
base_url=passthrough_url,
|
||||
api_key=passthrough_api_key,
|
||||
provider_data=provider_data,
|
||||
|
|
@ -81,15 +85,17 @@ class PassthroughInferenceAdapter(Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
params = {
|
||||
request_params = {
|
||||
"model_id": model.provider_resource_id,
|
||||
"content": content,
|
||||
"sampling_params": sampling_params,
|
||||
|
|
@ -98,16 +104,19 @@ class PassthroughInferenceAdapter(Inference):
|
|||
"logprobs": logprobs,
|
||||
}
|
||||
|
||||
params = {key: value for key, value in params.items() if value is not None}
|
||||
request_params = {key: value for key, value in request_params.items() if value is not None}
|
||||
|
||||
# cast everything to json dict
|
||||
json_params = self.cast_value_to_json_dict(request_params)
|
||||
|
||||
# only pass through the not None params
|
||||
return client.inference.completion(**params)
|
||||
return await client.inference.completion(**json_params)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
|
|
@ -116,10 +125,16 @@ class PassthroughInferenceAdapter(Inference):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
client = self._get_client()
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
params = {
|
||||
# TODO: revisit this remove tool_calls from messages logic
|
||||
for message in messages:
|
||||
if hasattr(message, "tool_calls"):
|
||||
message.tool_calls = None
|
||||
|
||||
request_params = {
|
||||
"model_id": model.provider_resource_id,
|
||||
"messages": messages,
|
||||
"sampling_params": sampling_params,
|
||||
|
|
@ -131,10 +146,41 @@ class PassthroughInferenceAdapter(Inference):
|
|||
"logprobs": logprobs,
|
||||
}
|
||||
|
||||
params = {key: value for key, value in params.items() if value is not None}
|
||||
|
||||
# only pass through the not None params
|
||||
return client.inference.chat_completion(**params)
|
||||
request_params = {key: value for key, value in request_params.items() if value is not None}
|
||||
|
||||
# cast everything to json dict
|
||||
json_params = self.cast_value_to_json_dict(request_params)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(json_params)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(json_params)
|
||||
|
||||
async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
client = self._get_client()
|
||||
response = await client.inference.chat_completion(**json_params)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=response.completion_message.content.text,
|
||||
stop_reason=response.completion_message.stop_reason,
|
||||
tool_calls=response.completion_message.tool_calls,
|
||||
),
|
||||
logprobs=response.logprobs,
|
||||
)
|
||||
|
||||
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
|
||||
client = self._get_client()
|
||||
stream_response = await client.inference.chat_completion(**json_params)
|
||||
|
||||
async for chunk in stream_response:
|
||||
chunk = chunk.to_dict()
|
||||
|
||||
# temporary hack to remove the metrics from the response
|
||||
chunk["metrics"] = []
|
||||
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
|
||||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
|
@ -147,10 +193,29 @@ class PassthroughInferenceAdapter(Inference):
|
|||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
return client.inference.embeddings(
|
||||
return await client.inference.embeddings(
|
||||
model_id=model.provider_resource_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
output_dimension=output_dimension,
|
||||
task_type=task_type,
|
||||
)
|
||||
|
||||
def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
json_params = {}
|
||||
for key, value in request_params.items():
|
||||
json_input = convert_pydantic_to_json_value(value)
|
||||
if isinstance(json_input, dict):
|
||||
json_input = {k: v for k, v in json_input.items() if v is not None}
|
||||
elif isinstance(json_input, list):
|
||||
json_input = [x for x in json_input if x is not None]
|
||||
new_input = []
|
||||
for x in json_input:
|
||||
if isinstance(x, dict):
|
||||
x = {k: v for k, v in x.items() if v is not None}
|
||||
new_input.append(x)
|
||||
json_input = new_input
|
||||
|
||||
json_params[key] = json_input
|
||||
|
||||
return json_params
|
||||
|
|
|
|||
|
|
@ -5,10 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .config import RunpodImplConfig
|
||||
from .runpod import RunpodInferenceAdapter
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RunpodImplConfig, _deps):
|
||||
from .runpod import RunpodInferenceAdapter
|
||||
|
||||
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = RunpodInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -21,3 +21,10 @@ class RunpodImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.RUNPOD_URL:}",
|
||||
"api_token": "${env.RUNPOD_API_TOKEN:}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from typing import AsyncGenerator
|
|||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.models.llama.datatypes import Message
|
||||
|
||||
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
|
@ -54,7 +53,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
|
@ -65,7 +64,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
@ -74,6 +73,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
|
|
|||
|
|
@ -42,9 +42,7 @@ from llama_stack.models.llama.datatypes import (
|
|||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
|
|
@ -74,7 +72,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
|
@ -85,7 +83,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
@ -94,6 +92,8 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
tool_config: Optional[ToolConfig] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
|
|
@ -291,14 +291,12 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
if not tool_calls:
|
||||
return []
|
||||
|
||||
for call in tool_calls:
|
||||
call_function_arguments = json.loads(call.function.arguments)
|
||||
|
||||
compitable_tool_calls = [
|
||||
ToolCall(
|
||||
call_id=call.id,
|
||||
tool_name=call.function.name,
|
||||
arguments=call_function_arguments,
|
||||
arguments=json.loads(call.function.arguments),
|
||||
arguments_json=call.function.arguments,
|
||||
)
|
||||
for call in tool_calls
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleInferenceImpl
|
||||
|
||||
impl = SampleInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleInferenceImpl(Inference):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_model(self, model: Model) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
|
@ -98,11 +98,13 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
|
@ -201,7 +203,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
|
|
@ -210,6 +212,8 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
|
|
|||
|
|
@ -26,5 +26,5 @@ class TogetherImplConfig(BaseModel):
|
|||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.together.xyz/v1",
|
||||
"api_key": "${env.TOGETHER_API_KEY}",
|
||||
"api_key": "${env.TOGETHER_API_KEY:}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from together import Together
|
||||
from together import AsyncTogether
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
|
@ -31,9 +31,8 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
|
|
@ -53,27 +52,34 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import TogetherImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
|
@ -88,34 +94,32 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
def _get_client(self) -> Together:
|
||||
together_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
together_api_key = self.config.api_key.get_secret_value()
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
return Together(api_key=together_api_key)
|
||||
def _get_client(self) -> AsyncTogether:
|
||||
if not self._client:
|
||||
together_api_key = None
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
self._client = AsyncTogether(api_key=together_api_key)
|
||||
return self._client
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self._get_client().completions.create(**params)
|
||||
client = self._get_client()
|
||||
r = await client.completions.create(**params)
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
s = self._get_client().completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
client = await self._get_client()
|
||||
stream = await client.completions.create(**params)
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
|
|
@ -150,7 +154,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
|
|
@ -159,6 +163,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
|
@ -178,25 +184,21 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
client = self._get_client()
|
||||
if "messages" in params:
|
||||
r = self._get_client().chat.completions.create(**params)
|
||||
r = await client.chat.completions.create(**params)
|
||||
else:
|
||||
r = self._get_client().completions.create(**params)
|
||||
r = await client.completions.create(**params)
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
client = self._get_client()
|
||||
if "messages" in params:
|
||||
stream = await client.chat.completions.create(**params)
|
||||
else:
|
||||
stream = await client.completions.create(**params)
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
if "messages" in params:
|
||||
s = self._get_client().chat.completions.create(**params)
|
||||
else:
|
||||
s = self._get_client().completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
|
|
@ -213,12 +215,14 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
assert not media_present, "Together does not support media for Completion requests"
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
|
||||
return {
|
||||
params = {
|
||||
"model": request.model,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
|
||||
}
|
||||
logger.debug(f"params to together: {params}")
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
|
@ -232,7 +236,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Together does not support media for embeddings"
|
||||
)
|
||||
r = self._get_client().embeddings.create(
|
||||
client = self._get_client()
|
||||
r = await client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,6 +25,10 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
|||
default="fake",
|
||||
description="The API token",
|
||||
)
|
||||
tls_verify: bool = Field(
|
||||
default=True,
|
||||
description="Whether to verify TLS certificates",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
|
|
@ -36,4 +40,5 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
|||
"url": url,
|
||||
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
|
||||
"api_token": "${env.VLLM_API_TOKEN:fake}",
|
||||
"tls_verify": "${env.VLLM_TLS_VERIFY:true}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,8 +7,11 @@ import json
|
|||
import logging
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_models.datatypes import StopReason, ToolCall
|
||||
from openai import OpenAI
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
)
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
|
@ -42,7 +45,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
|
|
@ -50,7 +53,6 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionResponse,
|
||||
UnparseableToolCall,
|
||||
convert_message_to_openai_dict,
|
||||
convert_tool_call,
|
||||
|
|
@ -88,15 +90,12 @@ def _convert_to_vllm_tool_calls_in_response(
|
|||
if not tool_calls:
|
||||
return []
|
||||
|
||||
call_function_arguments = None
|
||||
for call in tool_calls:
|
||||
call_function_arguments = json.loads(call.function.arguments)
|
||||
|
||||
return [
|
||||
ToolCall(
|
||||
call_id=call.id,
|
||||
tool_name=call.function.name,
|
||||
arguments=call_function_arguments,
|
||||
arguments=json.loads(call.function.arguments),
|
||||
arguments_json=call.function.arguments,
|
||||
)
|
||||
for call in tool_calls
|
||||
]
|
||||
|
|
@ -156,11 +155,14 @@ def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
|
|||
|
||||
|
||||
async def _process_vllm_chat_completion_stream_response(
|
||||
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
|
||||
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
|
||||
) -> AsyncGenerator:
|
||||
event_type = ChatCompletionResponseEventType.start
|
||||
tool_call_buf = UnparseableToolCall()
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
|
||||
continue
|
||||
choice = chunk.choices[0]
|
||||
if choice.finish_reason:
|
||||
args_str = tool_call_buf.arguments
|
||||
|
|
@ -169,7 +171,7 @@ async def _process_vllm_chat_completion_stream_response(
|
|||
args = {} if not args_str else json.loads(args_str)
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
|
||||
if args is not None:
|
||||
if args:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
|
|
@ -178,12 +180,13 @@ async def _process_vllm_chat_completion_stream_response(
|
|||
call_id=tool_call_buf.call_id,
|
||||
tool_name=tool_call_buf.tool_name,
|
||||
arguments=args,
|
||||
arguments_json=args_str,
|
||||
),
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
elif args_str:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
|
|
@ -225,7 +228,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing VLLM client with base_url={self.config.url}")
|
||||
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
self.client = AsyncOpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token,
|
||||
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
|
@ -237,11 +244,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
|
@ -260,7 +269,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
|
|
@ -269,7 +278,15 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
|
||||
# References:
|
||||
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
# * https://github.com/vllm-project/vllm/pull/10000
|
||||
if not tools and tool_config is not None:
|
||||
tool_config.tool_choice = ToolChoice.none
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
|
|
@ -286,10 +303,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
return await self._nonstream_chat_completion(request, self.client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
self, request: ChatCompletionRequest, client: AsyncOpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = client.chat.completions.create(**params)
|
||||
r = await client.chat.completions.create(**params)
|
||||
choice = r.choices[0]
|
||||
result = ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
|
|
@ -301,17 +318,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
return result
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
|
||||
# generator so this wrapper is not necessary?
|
||||
async def _to_async_generator():
|
||||
s = client.chat.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
stream = await client.chat.completions.create(**params)
|
||||
if len(request.tools) > 0:
|
||||
res = _process_vllm_chat_completion_stream_response(stream)
|
||||
else:
|
||||
|
|
@ -321,26 +331,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self.client.completions.create(**params)
|
||||
r = await self.client.completions.create(**params)
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# Wrapper for async generator similar
|
||||
async def _to_async_generator():
|
||||
stream = self.client.completions.create(**params)
|
||||
for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
stream = await self.client.completions.create(**params)
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.register_helper.register_model(model)
|
||||
res = self.client.models.list()
|
||||
available_models = [m.id for m in res]
|
||||
res = await self.client.models.list()
|
||||
available_models = [m.id async for m in res]
|
||||
if model.provider_resource_id not in available_models:
|
||||
raise ValueError(
|
||||
f"Model {model.provider_resource_id} is not being served by vLLM. "
|
||||
|
|
@ -396,7 +400,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
assert model.metadata.get("embedding_dimension")
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
|
||||
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
|
||||
response = self.client.embeddings.create(
|
||||
response = await self.client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
|
|
|
|||
|
|
@ -59,7 +59,8 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||
"""
|
||||
This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||
```content = [
|
||||
{
|
||||
"text": {
|
||||
|
|
@ -67,10 +68,8 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
}
|
||||
}
|
||||
]```
|
||||
However the incoming messages are of this type UserMessage(content=....) coming from
|
||||
https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py
|
||||
|
||||
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
|
||||
Incoming messages contain content, role . For now we will extract the content and
|
||||
default the "qualifiers": ["query"]
|
||||
"""
|
||||
|
||||
shield_params = shield.params
|
||||
|
|
|
|||
|
|
@ -4,14 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
from .config import NVIDIASafetyConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleAgentsImpl
|
||||
async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any:
|
||||
from .nvidia import NVIDIASafetyAdapter
|
||||
|
||||
impl = SampleAgentsImpl(config)
|
||||
impl = NVIDIASafetyAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
37
llama_stack/providers/remote/safety/nvidia/config.py
Normal file
37
llama_stack/providers/remote/safety/nvidia/config.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class NVIDIASafetyConfig(BaseModel):
|
||||
"""
|
||||
Configuration for the NVIDIA Guardrail microservice endpoint.
|
||||
|
||||
Attributes:
|
||||
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://0.0.0.0:7331
|
||||
config_id (str): The ID of the guardrails configuration to use from the configuration store
|
||||
(https://developer.nvidia.com/docs/nemo-microservices/guardrails/source/guides/configuration-store-guide.html)
|
||||
|
||||
"""
|
||||
|
||||
guardrails_service_url: str = Field(
|
||||
default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
|
||||
description="The url for accessing the guardrails service",
|
||||
)
|
||||
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
|
||||
"config_id": "self-check",
|
||||
}
|
||||
154
llama_stack/providers/remote/safety/nvidia/nvidia.py
Normal file
154
llama_stack/providers/remote/safety/nvidia/nvidia.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
||||
from .config import NVIDIASafetyConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: NVIDIASafetyConfig) -> None:
|
||||
"""
|
||||
Initialize the NVIDIASafetyAdapter with a given safety configuration.
|
||||
|
||||
Args:
|
||||
config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID.
|
||||
"""
|
||||
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
if not shield.provider_resource_id:
|
||||
raise ValueError("Shield model not provided.")
|
||||
|
||||
async def run_shield(
|
||||
self, shield_id: str, messages: List[Message], params: Optional[dict[str, Any]] = None
|
||||
) -> RunShieldResponse:
|
||||
"""
|
||||
Run a safety shield check against the provided messages.
|
||||
|
||||
Args:
|
||||
shield_id (str): The unique identifier for the shield to be used.
|
||||
messages (List[Message]): A list of Message objects representing the conversation history.
|
||||
params (Optional[dict[str, Any]]): Additional parameters for the shield check.
|
||||
|
||||
Returns:
|
||||
RunShieldResponse: The response containing safety violation details if any.
|
||||
|
||||
Raises:
|
||||
ValueError: If the shield with the provided shield_id is not found.
|
||||
"""
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
self.shield = NeMoGuardrails(self.config, shield.shield_id)
|
||||
return await self.shield.run(messages)
|
||||
|
||||
|
||||
class NeMoGuardrails:
|
||||
"""
|
||||
A class that encapsulates NVIDIA's guardrails safety logic.
|
||||
|
||||
Sends messages to the guardrails service and interprets the response to determine
|
||||
if a safety violation has occurred.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: NVIDIASafetyConfig,
|
||||
model: str,
|
||||
threshold: float = 0.9,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Initialize a NeMoGuardrails instance with the provided parameters.
|
||||
|
||||
Args:
|
||||
config (NVIDIASafetyConfig): The safety configuration containing the config ID and guardrails URL.
|
||||
model (str): The identifier or name of the model to be used for safety checks.
|
||||
threshold (float, optional): The threshold for flagging violations. Defaults to 0.9.
|
||||
temperature (float, optional): The temperature setting for the underlying model. Must be greater than 0. Defaults to 1.0.
|
||||
|
||||
Raises:
|
||||
ValueError: If temperature is less than or equal to 0.
|
||||
AssertionError: If config_id is not provided in the configuration.
|
||||
"""
|
||||
self.config_id = config.config_id
|
||||
self.model = model
|
||||
assert self.config_id is not None, "Must provide config id"
|
||||
if temperature <= 0:
|
||||
raise ValueError("Temperature must be greater than 0")
|
||||
|
||||
self.temperature = temperature
|
||||
self.threshold = threshold
|
||||
self.guardrails_service_url = config.guardrails_service_url
|
||||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
"""
|
||||
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
|
||||
|
||||
Args:
|
||||
messages (List[Message]): A list of Message objects to be checked for safety violations.
|
||||
|
||||
Returns:
|
||||
RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
|
||||
RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
|
||||
|
||||
Raises:
|
||||
requests.HTTPError: If the POST request fails.
|
||||
"""
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
}
|
||||
request_data = {
|
||||
"model": self.model,
|
||||
"messages": convert_pydantic_to_json_value(messages),
|
||||
"temperature": self.temperature,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": self.config_id,
|
||||
},
|
||||
}
|
||||
response = requests.post(
|
||||
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
|
||||
)
|
||||
response.raise_for_status()
|
||||
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
|
||||
response_json = response.json()
|
||||
if response_json["status"] == "blocked":
|
||||
user_message = "Sorry I cannot do this."
|
||||
metadata = response_json["rails_status"]
|
||||
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
return RunShieldResponse(violation=None)
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleSafetyImpl
|
||||
|
||||
impl = SampleSafetyImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.shields import Shield
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleSafetyImpl(Safety):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
# these are the safety shields the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
|
@ -31,7 +31,7 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
|
@ -77,12 +77,13 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
|
|||
"q": kwargs["query"],
|
||||
}
|
||||
|
||||
response = requests.get(
|
||||
url=self.url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
url=self.url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return ToolInvocationResult(content=json.dumps(self._clean_response(response.json())))
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -14,3 +14,9 @@ class BingSearchToolConfig(BaseModel):
|
|||
|
||||
api_key: Optional[str] = None
|
||||
top_k: int = 3
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.BING_API_KEY:}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
|
@ -30,7 +30,7 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
|
@ -74,8 +74,13 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
|
|||
"Accept": "application/json",
|
||||
}
|
||||
payload = {"q": kwargs["query"]}
|
||||
response = requests.get(url=url, params=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
url=url,
|
||||
params=payload,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = self._clean_brave_response(response.json())
|
||||
content_items = "\n".join([str(result) for result in results])
|
||||
return ToolInvocationResult(
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ModelContextProtocolConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
|
@ -30,7 +30,7 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
|
@ -66,10 +66,12 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
response = requests.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={"api_key": api_key, "query": kwargs["query"]},
|
||||
)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={"api_key": api_key, "query": kwargs["query"]},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return ToolInvocationResult(content=json.dumps(self._clean_tavily_response(response.json())))
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -13,3 +13,9 @@ class WolframAlphaToolConfig(BaseModel):
|
|||
"""Configuration for WolframAlpha Tool Runtime"""
|
||||
|
||||
api_key: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.WOLFRAM_ALPHA_API_KEY:}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
|
|
@ -31,7 +31,7 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool):
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
|
|
@ -73,11 +73,9 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
|||
"format": "plaintext",
|
||||
"output": "json",
|
||||
}
|
||||
response = requests.get(
|
||||
self.url,
|
||||
params=params,
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(params=params, url=self.url)
|
||||
response.raise_for_status()
|
||||
return ToolInvocationResult(content=json.dumps(self._clean_wolfram_alpha_response(response.json())))
|
||||
|
||||
def _clean_wolfram_alpha_response(self, wa_response):
|
||||
|
|
|
|||
|
|
@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel):
|
|||
url: str
|
||||
|
||||
@classmethod
|
||||
def sample_config(cls) -> Dict[str, Any]:
|
||||
return {"url": "{env.CHROMADB_URL}"}
|
||||
def sample_run_config(cls, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> Dict[str, Any]:
|
||||
return {"url": url}
|
||||
|
|
|
|||
21
llama_stack/providers/remote/vector_io/milvus/__init__.py
Normal file
21
llama_stack/providers/remote/vector_io/milvus/__init__.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MilvusVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
from .milvus import MilvusVectorIOAdapter
|
||||
|
||||
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
22
llama_stack/providers/remote/vector_io/milvus/config.py
Normal file
22
llama_stack/providers/remote/vector_io/milvus/config.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MilvusVectorIOConfig(BaseModel):
|
||||
uri: str
|
||||
token: Optional[str] = None
|
||||
consistency_level: str = "Strong"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
|
||||
175
llama_stack/providers/remote/vector_io/milvus/milvus.py
Normal file
175
llama_stack/providers/remote/vector_io/milvus/milvus.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
||||
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MilvusIndex(EmbeddingIndex):
|
||||
def __init__(self, client: MilvusClient, collection_name: str, consistency_level="Strong"):
|
||||
self.client = client
|
||||
self.collection_name = collection_name.replace("-", "_")
|
||||
self.consistency_level = consistency_level
|
||||
|
||||
async def delete(self):
|
||||
if self.client.has_collection(self.collection_name):
|
||||
self.client.drop_collection(collection_name=self.collection_name)
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
)
|
||||
if not self.client.has_collection(self.collection_name):
|
||||
self.client.create_collection(
|
||||
self.collection_name,
|
||||
dimension=len(embeddings[0]),
|
||||
auto_id=True,
|
||||
consistency_level=self.consistency_level,
|
||||
)
|
||||
|
||||
data = []
|
||||
for chunk, embedding in zip(chunks, embeddings, strict=False):
|
||||
chunk_id = generate_chunk_id(chunk.metadata["document_id"], chunk.content)
|
||||
|
||||
data.append(
|
||||
{
|
||||
"chunk_id": chunk_id,
|
||||
"vector": embedding,
|
||||
"chunk_content": chunk.model_dump(),
|
||||
}
|
||||
)
|
||||
try:
|
||||
self.client.insert(
|
||||
self.collection_name,
|
||||
data=data,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||
raise e
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
search_res = self.client.search(
|
||||
collection_name=self.collection_name,
|
||||
data=[embedding],
|
||||
limit=k,
|
||||
output_fields=["*"],
|
||||
search_params={"params": {"radius": score_threshold}},
|
||||
)
|
||||
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
|
||||
scores = [res["distance"] for res in search_res[0]]
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
self, config: Union[RemoteMilvusVectorIOConfig, InlineMilvusVectorIOConfig], inference_api: Api.inference
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.cache = {}
|
||||
self.client = None
|
||||
self.inference_api = inference_api
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
||||
else:
|
||||
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
||||
uri = os.path.expanduser(self.config.db_path)
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
consistency_level = self.config.consistency_level
|
||||
else:
|
||||
consistency_level = "Strong"
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
return index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id in self.cache:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: List[Chunk],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
|
||||
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
|
||||
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
|
||||
hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
|
||||
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
|
||||
|
||||
|
||||
# TODO: refactor this generate_chunk_id along with the `sqlite-vec` implementation into a separate utils file
|
||||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
|
@ -16,3 +18,15 @@ class PGVectorVectorIOConfig(BaseModel):
|
|||
db: str = Field(default="postgres")
|
||||
user: str = Field(default="postgres")
|
||||
password: str = Field(default="mysecretpassword")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
host: str = "${env.PGVECTOR_HOST:localhost}",
|
||||
port: int = "${env.PGVECTOR_PORT:5432}",
|
||||
db: str = "${env.PGVECTOR_DB}",
|
||||
user: str = "${env.PGVECTOR_USER}",
|
||||
password: str = "${env.PGVECTOR_PASSWORD}",
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
return {"host": host, "port": port, "db": db, "user": user, "password": password}
|
||||
|
|
|
|||
|
|
@ -58,7 +58,11 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
def __init__(self, vector_db: VectorDB, dimension: int, conn):
|
||||
self.conn = conn
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
self.table_name = f"vector_store_{vector_db.identifier}"
|
||||
# Sanitize the table name by replacing hyphens with underscores
|
||||
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens
|
||||
# when created with patterns like "test-vector-db-{uuid4()}"
|
||||
sanitized_identifier = vector_db.identifier.replace("-", "_")
|
||||
self.table_name = f"vector_store_{sanitized_identifier}"
|
||||
|
||||
cur.execute(
|
||||
f"""
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -23,4 +23,9 @@ class QdrantVectorIOConfig(BaseModel):
|
|||
prefix: Optional[str] = None
|
||||
timeout: Optional[int] = None
|
||||
host: Optional[str] = None
|
||||
path: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.QDRANT_API_KEY}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
|
|
@ -16,12 +16,13 @@ from llama_stack.apis.inference import InterleavedContent
|
|||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
||||
from .config import QdrantVectorIOConfig
|
||||
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
CHUNK_ID_KEY = "_chunk_id"
|
||||
|
|
@ -99,17 +100,19 @@ class QdrantIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: QdrantVectorIOConfig, inference_api: Api.inference) -> None:
|
||||
def __init__(
|
||||
self, config: Union[RemoteQdrantVectorIOConfig, InlineQdrantVectorIOConfig], inference_api: Api.inference
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
|
||||
self.client: AsyncQdrantClient = None
|
||||
self.cache = {}
|
||||
self.inference_api = inference_api
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
await self.client.close()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
|
|
@ -123,6 +126,11 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id in self.cache:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleVectorIOConfig, _deps) -> Any:
|
||||
from .sample import SampleVectorIOImpl
|
||||
|
||||
impl = SampleVectorIOImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleVectorIOConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
|
||||
from .config import SampleVectorIOConfig
|
||||
|
||||
|
||||
class SampleVectorIOImpl(VectorIO):
|
||||
def __init__(self, config: SampleVectorIOConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
# these are the vector dbs the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
|
@ -13,4 +15,6 @@ class WeaviateRequestProviderData(BaseModel):
|
|||
|
||||
|
||||
class WeaviateVectorIOConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue