Matthew Farrellee 2025-10-03 12:20:15 -07:00 committed by GitHub
commit 8748baf090
57 changed files with 12520 additions and 1254 deletions


@ -10,6 +10,6 @@ from .config import AnthropicConfig
async def get_adapter_impl(config: AnthropicConfig, _deps):
from .anthropic import AnthropicInferenceAdapter
impl = AnthropicInferenceAdapter(config)
impl = AnthropicInferenceAdapter(config=config)
await impl.initialize()
return impl
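The change from AnthropicInferenceAdapter(config) to AnthropicInferenceAdapter(config=config), repeated for every provider below, follows from the adapters becoming pydantic models via the reworked OpenAIMixin (see the openai_mixin.py changes near the end of this diff). A minimal sketch with hypothetical DemoConfig/DemoAdapter names, since pydantic models accept field values only as keyword arguments:

from pydantic import BaseModel

class DemoConfig(BaseModel):
    url: str = "http://localhost"

class DemoAdapter(BaseModel):
    config: DemoConfig

DemoAdapter(config=DemoConfig())  # ok: fields are passed by keyword
# DemoAdapter(DemoConfig())       # TypeError: pydantic models reject positional arguments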


@ -4,13 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AnthropicConfig
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class AnthropicInferenceAdapter(OpenAIMixin):
config: AnthropicConfig
provider_data_api_key_field: str = "anthropic_api_key"
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
# TODO: add support for voyageai, which is where these models are hosted
# embedding_model_metadata = {
@ -23,22 +25,8 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
# }
def __init__(self, config: AnthropicConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="anthropic",
api_key_from_config=config.api_key,
provider_data_api_key_field="anthropic_api_key",
)
self.config = config
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self):
return "https://api.anthropic.com/v1"


@ -10,6 +10,6 @@ from .config import AzureConfig
async def get_adapter_impl(config: AzureConfig, _deps):
from .azure import AzureInferenceAdapter
impl = AzureInferenceAdapter(config)
impl = AzureInferenceAdapter(config=config)
await impl.initialize()
return impl


@ -4,31 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from urllib.parse import urljoin
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AzureConfig
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: AzureConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="azure",
api_key_from_config=config.api_key.get_secret_value(),
provider_data_api_key_field="azure_api_key",
openai_compat_api_base=str(config.api_base),
)
self.config = config
class AzureInferenceAdapter(OpenAIMixin):
config: AzureConfig
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
provider_data_api_key_field: str = "azure_api_key"
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
def get_base_url(self) -> str:
"""
@ -38,25 +27,25 @@ class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
return urljoin(str(self.config.api_base), "/openai/v1")
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# # Get base parameters from parent
# params = await super()._get_params(request)
# Add Azure specific parameters
provider_data = self.get_request_provider_data()
if provider_data:
if getattr(provider_data, "azure_api_key", None):
params["api_key"] = provider_data.azure_api_key
if getattr(provider_data, "azure_api_base", None):
params["api_base"] = provider_data.azure_api_base
if getattr(provider_data, "azure_api_version", None):
params["api_version"] = provider_data.azure_api_version
if getattr(provider_data, "azure_api_type", None):
params["api_type"] = provider_data.azure_api_type
else:
params["api_key"] = self.config.api_key.get_secret_value()
params["api_base"] = str(self.config.api_base)
params["api_version"] = self.config.api_version
params["api_type"] = self.config.api_type
# # Add Azure specific parameters
# provider_data = self.get_request_provider_data()
# if provider_data:
# if getattr(provider_data, "azure_api_key", None):
# params["api_key"] = provider_data.azure_api_key
# if getattr(provider_data, "azure_api_base", None):
# params["api_base"] = provider_data.azure_api_base
# if getattr(provider_data, "azure_api_version", None):
# params["api_version"] = provider_data.azure_api_version
# if getattr(provider_data, "azure_api_type", None):
# params["api_type"] = provider_data.azure_api_type
# else:
# params["api_key"] = self.config.api_key.get_secret_value()
# params["api_base"] = str(self.config.api_base)
# params["api_version"] = self.config.api_version
# params["api_type"] = self.config.api_type
return params
# return params
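The Azure get_base_url above relies on a specific urljoin behaviour: a second argument with a leading "/" replaces any path already present on the base URL. A small illustration with a hypothetical Azure endpoint (values are not from this diff):

from urllib.parse import urljoin

print(urljoin("https://example.openai.azure.com/", "/openai/v1"))
# https://example.openai.azure.com/openai/v1
print(urljoin("https://example.openai.azure.com/some/path", "/openai/v1"))
# https://example.openai.azure.com/openai/v1  (the existing path is dropped)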


@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
impl = CerebrasInferenceAdapter(config)
impl = CerebrasInferenceAdapter(config=config)
await impl.initialize()


@ -11,7 +11,6 @@ from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
Inference,
OpenAIEmbeddingsResponse,
TopKSamplingStrategy,
)
@ -27,14 +26,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import CerebrasImplConfig
class CerebrasInferenceAdapter(
OpenAIMixin,
Inference,
):
def __init__(self, config: CerebrasImplConfig) -> None:
self.config = config
class CerebrasInferenceAdapter(OpenAIMixin):
config: CerebrasImplConfig
# TODO: make this use provider data, etc. like other providers
_cerebras_client: AsyncCerebras | None = None
async def initialize(self) -> None:
self._cerebras_client = AsyncCerebras(
base_url=self.config.base_url,
api_key=self.config.api_key.get_secret_value(),
@ -46,12 +43,6 @@ class CerebrasInferenceAdapter(
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
raise ValueError("`top_k` not supported by Cerebras")


@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
from .databricks import DatabricksInferenceAdapter
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
impl = DatabricksInferenceAdapter(config=config)
await impl.initialize()
return impl


@ -9,7 +9,6 @@ from typing import Any
from databricks.sdk import WorkspaceClient
from llama_stack.apis.inference import (
Inference,
Model,
OpenAICompletion,
)
@ -22,31 +21,21 @@ from .config import DatabricksImplConfig
logger = get_logger(name=__name__, category="inference::databricks")
class DatabricksInferenceAdapter(
OpenAIMixin,
Inference,
):
class DatabricksInferenceAdapter(OpenAIMixin):
config: DatabricksImplConfig
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
embedding_model_metadata = {
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
}
def __init__(self, config: DatabricksImplConfig) -> None:
self.config = config
def get_api_key(self) -> str:
return self.config.api_token.get_secret_value()
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def openai_completion(
self,
model: str,


@ -17,6 +17,6 @@ async def get_adapter_impl(config: FireworksImplConfig, _deps):
from .fireworks import FireworksInferenceAdapter
assert isinstance(config, FireworksImplConfig), f"Unexpected config type: {type(config)}"
impl = FireworksInferenceAdapter(config)
impl = FireworksInferenceAdapter(config=config)
await impl.initialize()
return impl


@ -9,7 +9,6 @@ from fireworks.client import Fireworks
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
LogProbConfig,
ResponseFormat,
ResponseFormatType,
@ -17,9 +16,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
@ -35,23 +31,14 @@ from .config import FireworksImplConfig
logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData):
embedding_model_metadata = {
class FireworksInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
config: FireworksImplConfig
embedding_model_metadata: dict[str, dict[str, int]] = {
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
}
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self)
self.config = config
self.allowed_models = config.allowed_models
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
def get_api_key(self) -> str:
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:


@ -10,6 +10,6 @@ from .config import GeminiConfig
async def get_adapter_impl(config: GeminiConfig, _deps):
from .gemini import GeminiInferenceAdapter
impl = GeminiInferenceAdapter(config)
impl = GeminiInferenceAdapter(config=config)
await impl.initialize()
return impl


@ -4,33 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import GeminiConfig
class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
embedding_model_metadata = {
class GeminiInferenceAdapter(OpenAIMixin):
config: GeminiConfig
provider_data_api_key_field: str = "gemini_api_key"
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
}
def __init__(self, config: GeminiConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="gemini",
api_key_from_config=config.api_key,
provider_data_api_key_field="gemini_api_key",
)
self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self):
return "https://generativelanguage.googleapis.com/v1beta/openai/"
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()


@ -11,5 +11,5 @@ async def get_adapter_impl(config: GroqConfig, _deps):
# import dynamically so the import is used only when it is needed
from .groq import GroqInferenceAdapter
adapter = GroqInferenceAdapter(config)
adapter = GroqInferenceAdapter(config=config)
return adapter


@ -6,30 +6,16 @@
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
_config: GroqConfig
class GroqInferenceAdapter(OpenAIMixin):
config: GroqConfig
def __init__(self, config: GroqConfig):
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="groq",
api_key_from_config=config.api_key,
provider_data_api_key_field="groq_api_key",
)
self.config = config
provider_data_api_key_field: str = "groq_api_key"
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self) -> str:
return f"{self.config.url}/openai/v1"
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()


@ -4,14 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import LlamaCompatConfig
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider:
async def get_adapter_impl(config: LlamaCompatConfig, _deps):
# import dynamically so the import is used only when it is needed
from .llama import LlamaCompatInferenceAdapter
adapter = LlamaCompatInferenceAdapter(config)
adapter = LlamaCompatInferenceAdapter(config=config)
return adapter


@ -5,38 +5,21 @@
# the root directory of this source tree.
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class LlamaCompatInferenceAdapter(OpenAIMixin):
config: LlamaCompatConfig
provider_data_api_key_field: str = "llama_api_key"
"""
Llama API Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
_config: LlamaCompatConfig
def __init__(self, config: LlamaCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="meta_llama",
api_key_from_config=config.api_key,
provider_data_api_key_field="llama_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self) -> str:
"""
@ -45,9 +28,3 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
:return: The Llama API base URL
"""
return self.config.openai_compat_api_base
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()


@ -15,7 +15,8 @@ async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
if not isinstance(config, NVIDIAConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = NVIDIAInferenceAdapter(config)
adapter = NVIDIAInferenceAdapter(config=config)
await adapter.initialize()
return adapter


@ -8,7 +8,6 @@
from openai import NOT_GIVEN
from llama_stack.apis.inference import (
Inference,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
@ -22,7 +21,9 @@ from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
class NVIDIAInferenceAdapter(OpenAIMixin):
config: NVIDIAConfig
"""
NVIDIA Inference Adapter for Llama Stack.
@ -37,32 +38,21 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
"""
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
}
def __init__(self, config: NVIDIAConfig) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
async def initialize(self) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
if _is_nvidia_hosted(config):
if not config.api_key:
if _is_nvidia_hosted(self.config):
if not self.config.api_key:
raise RuntimeError(
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
)
# elif self._config.api_key:
#
# we don't raise this warning because a user may have deployed their
# self-hosted NIM with an API key requirement.
#
# warnings.warn(
# "API key is not required for self-hosted NVIDIA NIM. "
# "Consider removing the api_key from the configuration."
# )
self._config = config
def get_api_key(self) -> str:
"""
@ -70,7 +60,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
:return: The NVIDIA API key
"""
return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
def get_base_url(self) -> str:
"""
@ -78,7 +68,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
:return: The NVIDIA API base URL
"""
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
async def openai_embeddings(
self,


@ -10,6 +10,6 @@ from .config import OllamaImplConfig
async def get_adapter_impl(config: OllamaImplConfig, _deps):
from .ollama import OllamaInferenceAdapter
impl = OllamaInferenceAdapter(config)
impl = OllamaInferenceAdapter(config=config)
await impl.initialize()
return impl


@ -6,7 +6,6 @@
import asyncio
from typing import Any
from ollama import AsyncClient as AsyncOllamaClient
@ -16,48 +15,30 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.inference import (
ChatCompletionRequest,
GrammarResponseFormat,
InferenceProvider,
JsonSchemaResponseFormat,
Message,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
convert_image_content_to_url,
request_has_media,
)
logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
OpenAIMixin,
ModelRegistryHelper,
InferenceProvider,
ModelsProtocolPrivate,
):
class OllamaInferenceAdapter(OpenAIMixin):
config: OllamaImplConfig
# automatically set by the resolver when instantiating the provider
__provider_id__: str
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"all-minilm:l6-v2": {
"embedding_dimension": 384,
"context_length": 512,
@ -76,29 +57,8 @@ class OllamaInferenceAdapter(
},
}
def __init__(self, config: OllamaImplConfig) -> None:
# TODO: remove ModelRegistryHelper.__init__ when completion and
# chat_completion are. this exists to satisfy the input /
# output processing for llama models. specifically,
# tool_calling is handled by raw template processing,
# instead of using the /api/chat endpoint w/ tools=...
ModelRegistryHelper.__init__(
self,
model_entries=[
build_hf_repo_model_entry(
"llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
],
)
self.config = config
# Ollama does not support image urls, so we need to download the image and convert it to base64
self.download_images = True
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
download_images: bool = True
_clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
@property
def ollama_client(self) -> AsyncOllamaClient:
@ -142,50 +102,6 @@ class OllamaInferenceAdapter(
async def shutdown(self) -> None:
self._clients.clear()
async def _get_model(self, model_id: str) -> Model:
if not self.model_store:
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
async def _get_params(self, request: ChatCompletionRequest) -> dict:
sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens.
if sampling_options.get("max_tokens") is not None:
sampling_options["num_predict"] = sampling_options["max_tokens"]
input_dict: dict[str, Any] = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if media_present or not llama_model:
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
# flatten the list of lists
input_dict["messages"] = [item for sublist in contents for item in sublist]
else:
input_dict["raw"] = True
input_dict["prompt"] = await chat_completion_request_to_prompt(
request,
llama_model,
)
if fmt := request.response_format:
if isinstance(fmt, JsonSchemaResponseFormat):
input_dict["format"] = fmt.json_schema
elif isinstance(fmt, GrammarResponseFormat):
raise NotImplementedError("Grammar response format is not supported")
else:
raise ValueError(f"Unknown response format type: {fmt.type}")
params = {
"model": request.model,
**input_dict,
"options": sampling_options,
"stream": request.stream,
}
logger.debug(f"params to ollama: {params}")
return params
async def register_model(self, model: Model) -> Model:
if await self.check_model_availability(model.provider_model_id):
return model


@ -10,6 +10,6 @@ from .config import OpenAIConfig
async def get_adapter_impl(config: OpenAIConfig, _deps):
from .openai import OpenAIInferenceAdapter
impl = OpenAIInferenceAdapter(config)
impl = OpenAIInferenceAdapter(config=config)
await impl.initialize()
return impl


@ -5,7 +5,6 @@
# the root directory of this source tree.
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
@ -14,52 +13,24 @@ logger = get_logger(name=__name__, category="inference::openai")
#
# This OpenAI adapter implements Inference methods using two mixins -
# This OpenAI adapter implements Inference methods using OpenAIMixin
#
# | Inference Method | Implementation Source |
# |----------------------------|--------------------------|
# | completion | LiteLLMOpenAIMixin |
# | chat_completion | LiteLLMOpenAIMixin |
# | embedding | LiteLLMOpenAIMixin |
# | openai_completion | OpenAIMixin |
# | openai_chat_completion | OpenAIMixin |
# | openai_embeddings | OpenAIMixin |
#
class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class OpenAIInferenceAdapter(OpenAIMixin):
"""
OpenAI Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
embedding_model_metadata = {
config: OpenAIConfig
provider_data_api_key_field: str = "openai_api_key"
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
}
def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="openai",
api_key_from_config=config.api_key,
provider_data_api_key_field="openai_api_key",
)
self.config = config
# we set is_openai_compat so users can use the canonical
# openai model names like "gpt-4" or "gpt-3.5-turbo"
# and the model name will be translated to litellm's
# "openai/gpt-4" or "openai/gpt-3.5-turbo" transparently.
# if we do not set this, users will be exposed to the
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self) -> str:
"""
@ -68,9 +39,3 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
Returns the OpenAI API base URL from the configuration.
"""
return self.config.base_url
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()


@ -31,12 +31,6 @@ class PassthroughInferenceAdapter(Inference):
ModelRegistryHelper.__init__(self)
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def unregister_model(self, model_id: str) -> None:
pass


@ -53,12 +53,6 @@ class RunpodInferenceAdapter(
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
self.config = config
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),


@ -11,6 +11,6 @@ async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
from .sambanova import SambaNovaInferenceAdapter
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
impl = SambaNovaInferenceAdapter(config)
impl = SambaNovaInferenceAdapter(config=config)
await impl.initialize()
return impl


@ -5,39 +5,22 @@
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import SambaNovaImplConfig
class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class SambaNovaInferenceAdapter(OpenAIMixin):
config: SambaNovaImplConfig
provider_data_api_key_field: str = "sambanova_api_key"
download_images: bool = True # SambaNova does not support image downloads server-side, perform them on the client
"""
SambaNova Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of LiteLLMOpenAIMixin.check_model_availability().
- OpenAIMixin.check_model_availability() queries the /v1/models to check if a model exists
- LiteLLMOpenAIMixin.check_model_availability() checks the static registry within LiteLLM
"""
def __init__(self, config: SambaNovaImplConfig):
self.config = config
self.environment_available_models: list[str] = []
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="sambanova",
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
provider_data_api_key_field="sambanova_api_key",
openai_compat_api_base=self.config.url,
download_images=True, # SambaNova requires base64 image encoding
json_schema_strict=False, # SambaNova doesn't support strict=True yet
)
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value() if self.config.api_key else ""
def get_base_url(self) -> str:
"""


@ -17,6 +17,6 @@ async def get_adapter_impl(config: TogetherImplConfig, _deps):
from .together import TogetherInferenceAdapter
assert isinstance(config, TogetherImplConfig), f"Unexpected config type: {type(config)}"
impl = TogetherInferenceAdapter(config)
impl = TogetherInferenceAdapter(config=config)
await impl.initialize()
return impl


@ -11,7 +11,6 @@ from together.constants import BASE_URL
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
LogProbConfig,
OpenAIEmbeddingsResponse,
ResponseFormat,
@ -22,7 +21,6 @@ from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
from llama_stack.apis.models import Model, ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
@ -38,8 +36,10 @@ from .config import TogetherImplConfig
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData):
embedding_model_metadata = {
class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
config: TogetherImplConfig
embedding_model_metadata: dict[str, dict[str, int]] = {
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
"BAAI/bge-base-en-v1.5": {"embedding_dimension": 768, "context_length": 512},
@ -47,11 +47,7 @@ class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData)
"intfloat/multilingual-e5-large-instruct": {"embedding_dimension": 1024, "context_length": 512},
}
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self)
self.config = config
self.allowed_models = config.allowed_models
self._model_cache: dict[str, Model] = {}
_model_cache: dict[str, Model] = {}
def get_api_key(self):
return self.config.api_key.get_secret_value()
@ -59,12 +55,6 @@ class TogetherInferenceAdapter(OpenAIMixin, Inference, NeedsRequestProviderData)
def get_base_url(self):
return BASE_URL
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
def _get_client(self) -> AsyncTogether:
together_api_key = None
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None


@ -10,6 +10,6 @@ from .config import VertexAIConfig
async def get_adapter_impl(config: VertexAIConfig, _deps):
from .vertexai import VertexAIInferenceAdapter
impl = VertexAIInferenceAdapter(config)
impl = VertexAIInferenceAdapter(config=config)
await impl.initialize()
return impl


@ -4,29 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
import google.auth.transport.requests
from google.auth import default
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import VertexAIConfig
class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: VertexAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="vertex_ai",
api_key_from_config=None, # Vertex AI uses ADC, not API keys
provider_data_api_key_field="vertex_project", # Use project for validation
)
self.config = config
class VertexAIInferenceAdapter(OpenAIMixin):
config: VertexAIConfig
provider_data_api_key_field: str = "vertex_project"
def get_api_key(self) -> str:
"""
@ -54,22 +44,22 @@ class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# # Get base parameters from parent
# params = await super()._get_params(request)
# Add Vertex AI specific parameters
provider_data = self.get_request_provider_data()
if provider_data:
if getattr(provider_data, "vertex_project", None):
params["vertex_project"] = provider_data.vertex_project
if getattr(provider_data, "vertex_location", None):
params["vertex_location"] = provider_data.vertex_location
else:
params["vertex_project"] = self.config.project
params["vertex_location"] = self.config.location
# # Add Vertex AI specific parameters
# provider_data = self.get_request_provider_data()
# if provider_data:
# if getattr(provider_data, "vertex_project", None):
# params["vertex_project"] = provider_data.vertex_project
# if getattr(provider_data, "vertex_location", None):
# params["vertex_location"] = provider_data.vertex_location
# else:
# params["vertex_project"] = self.config.project
# params["vertex_location"] = self.config.location
# Remove api_key since Vertex AI uses ADC
params.pop("api_key", None)
# # Remove api_key since Vertex AI uses ADC
# params.pop("api_key", None)
return params
# return params


@ -17,6 +17,6 @@ async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
from .vllm import VLLMInferenceAdapter
assert isinstance(config, VLLMInferenceAdapterConfig), f"Unexpected config type: {type(config)}"
impl = VLLMInferenceAdapter(config)
impl = VLLMInferenceAdapter(config=config)
await impl.initialize()
return impl


@ -3,56 +3,27 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import AsyncIterator
from typing import Any
from urllib.parse import urljoin
import httpx
from openai import APIConnectionError
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from pydantic import ConfigDict
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
ModelStore,
OpenAIChatCompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ToolChoice,
ToolDefinition,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
UnparseableToolCall,
convert_message_to_openai_dict,
convert_tool_call,
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -61,210 +32,15 @@ from .config import VLLMInferenceAdapterConfig
log = get_logger(name=__name__, category="inference::vllm")
def build_hf_repo_model_entries():
return [
build_hf_repo_model_entry(
model.huggingface_repo,
model.descriptor(),
)
for model in all_registered_models()
if model.huggingface_repo
]
class VLLMInferenceAdapter(OpenAIMixin):
config: VLLMInferenceAdapterConfig
model_config = ConfigDict(arbitrary_types_allowed=True)
def _convert_to_vllm_tool_calls_in_response(
tool_calls,
) -> list[ToolCall]:
if not tool_calls:
return []
provider_data_api_key_field: str = "vllm_api_token"
return [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=call.function.arguments,
)
for call in tool_calls
]
def _convert_to_vllm_tools_in_request(tools: list[ToolDefinition]) -> list[dict]:
compat_tools = []
for tool in tools:
# The tool.tool_name can be a str or a BuiltinTool enum. If
# it's the latter, convert to a string.
tool_name = tool.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
compat_tool = {
"type": "function",
"function": {
"name": tool_name,
"description": tool.description,
"parameters": tool.input_schema
or {
"type": "object",
"properties": {},
"required": [],
},
},
}
compat_tools.append(compat_tool)
return compat_tools
def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
return {
"stop": StopReason.end_of_turn,
"length": StopReason.out_of_tokens,
"tool_calls": StopReason.end_of_message,
}.get(finish_reason, StopReason.end_of_turn)
def _process_vllm_chat_completion_end_of_stream(
finish_reason: str | None,
last_chunk_content: str | None,
current_event_type: ChatCompletionResponseEventType,
tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
) -> list[OpenAIChatCompletionChunk]:
chunks = []
if finish_reason is not None:
stop_reason = _convert_to_vllm_finish_reason(finish_reason)
else:
stop_reason = StopReason.end_of_message
tool_call_bufs = tool_call_bufs or {}
for _index, tool_call_buf in sorted(tool_call_bufs.items()):
args_str = tool_call_buf.arguments or "{}"
try:
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=current_event_type,
delta=ToolCallDelta(
tool_call=ToolCall(
call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name,
arguments=args_str,
),
parse_status=ToolCallParseStatus.succeeded,
),
)
)
)
except Exception as e:
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=str(tool_call_buf),
parse_status=ToolCallParseStatus.failed,
),
)
)
)
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=last_chunk_content or ""),
logprobs=None,
stop_reason=stop_reason,
)
)
)
return chunks
async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
event_type = ChatCompletionResponseEventType.progress
tool_call_bufs: dict[str, UnparseableToolCall] = {}
end_of_stream_processed = False
async for chunk in stream:
if not chunk.choices:
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
return
choice = chunk.choices[0]
if choice.delta.tool_calls:
for delta_tool_call in choice.delta.tool_calls:
tool_call = convert_tool_call(delta_tool_call)
if delta_tool_call.index not in tool_call_bufs:
tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
tool_call_buf = tool_call_bufs[delta_tool_call.index]
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
tool_call_buf.arguments += (
tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
)
if choice.finish_reason:
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=choice.finish_reason,
last_chunk_content=choice.delta.content,
current_event_type=event_type,
tool_call_bufs=tool_call_bufs,
)
for c in chunks:
yield c
end_of_stream_processed = True
elif not choice.delta.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=TextDelta(text=choice.delta.content or ""),
logprobs=None,
)
)
event_type = ChatCompletionResponseEventType.progress
if end_of_stream_processed:
return
# the stream ended without a chunk containing finish_reason - we have to generate the
# respective completion chunks manually
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
)
for c in chunks:
yield c
class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
model_entries=build_hf_repo_model_entries(),
litellm_provider_name="vllm",
api_key_from_config=config.api_token,
provider_data_api_key_field="vllm_api_token",
openai_compat_api_base=config.url,
)
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_token or ""
def get_base_url(self) -> str:
"""Get the base URL from config."""
@ -290,19 +66,13 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=self.__provider_id__,
provider_id=self.__provider_id__, # type: ignore[attr-defined]
metadata={},
model_type=model_type,
)
)
return models
async def shutdown(self) -> None:
pass
async def unregister_model(self, model_id: str) -> None:
pass
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the remote vLLM server.
@ -324,63 +94,9 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def _get_model(self, model_id: str) -> Model:
if not self.model_store:
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
async def register_model(self, model: Model) -> Model:
try:
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
try:
res = self.client.models.list()
except APIConnectionError as e:
raise ValueError(
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
) from e
available_models = [m.id async for m in res]
if model.provider_resource_id not in available_models:
raise ValueError(
f"Model {model.provider_resource_id} is not being served by vLLM. "
f"Available models: {', '.join(available_models)}"
)
return model
async def _get_params(self, request: ChatCompletionRequest) -> dict:
options = get_sampling_options(request.sampling_params)
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
input_dict: dict[str, Any] = {}
# Only include the 'tools' param if there is any. It can break things if an empty list is sent to the vLLM.
if isinstance(request, ChatCompletionRequest) and request.tools:
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
if fmt := request.response_format:
if isinstance(fmt, JsonSchemaResponseFormat):
input_dict["extra_body"] = {"guided_json": fmt.json_schema}
elif isinstance(fmt, GrammarResponseFormat):
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
if request.logprobs and request.logprobs.top_k:
input_dict["logprobs"] = request.logprobs.top_k
return {
"model": request.model,
**input_dict,
"stream": request.stream,
**options,
}
async def openai_chat_completion(
self,
model: str,

View file

@ -65,12 +65,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
self._project_id = self._config.project_id
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
def _get_client(self, model_id) -> Model:
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
config_url = self._config.url


@ -11,6 +11,7 @@ from collections.abc import AsyncIterator
from typing import Any
from openai import NOT_GIVEN, AsyncOpenAI
from pydantic import BaseModel, ConfigDict
from llama_stack.apis.inference import (
Model,
@ -26,14 +27,14 @@ from llama_stack.apis.inference import (
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
logger = get_logger(name=__name__, category="providers::utils")
class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
"""
Mixin class that provides OpenAI-specific functionality for inference providers.
This class handles direct OpenAI API calls using the AsyncOpenAI client.
@ -48,6 +49,11 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
The model_store is set in routing_tables/common.py during provider initialization.
"""
# Allow extra fields so the routing infra can inject model_store, __provider_id__, etc.
model_config = ConfigDict(extra="allow")
config: RemoteInferenceProviderConfig
# Allow subclasses to control whether the 'id' field in OpenAI responses
# is overwritten with a client-side generated id.
#
@ -73,9 +79,6 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
# Optional field name in provider data to look for API key, which takes precedence
provider_data_api_key_field: str | None = None
# automatically set by the resolver when instantiating the provider
__provider_id__: str
@abstractmethod
def get_api_key(self) -> str:
"""
@ -111,6 +114,26 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
"""
return {}
async def initialize(self) -> None:
"""
Initialize the OpenAI mixin.
This method provides a default implementation that does nothing.
Subclasses can override this method to perform initialization tasks
such as setting up clients, validating configurations, etc.
"""
pass
async def shutdown(self) -> None:
"""
Shutdown the OpenAI mixin.
This method provides a default implementation that does nothing.
Subclasses can override this method to perform cleanup tasks
such as closing connections, releasing resources, etc.
"""
pass
@property
def client(self) -> AsyncOpenAI:
"""
@ -371,7 +394,7 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
async def register_model(self, model: Model) -> Model:
if not await self.check_model_availability(model.provider_model_id):
raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}")
raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}") # type: ignore[attr-defined]
return model
async def unregister_model(self, model_id: str) -> None:
@ -425,3 +448,29 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
async def should_refresh_models(self) -> bool:
return False
#
# The model_dump implementations are to avoid serializing the extra fields,
# e.g. model_store, which are not pydantic.
#
def _filter_fields(self, **kwargs):
"""Helper to exclude extra fields from serialization."""
# Exclude any extra fields stored in __pydantic_extra__
if hasattr(self, "__pydantic_extra__") and self.__pydantic_extra__:
exclude = kwargs.get("exclude", set())
if not isinstance(exclude, set):
exclude = set(exclude) if exclude else set()
exclude.update(self.__pydantic_extra__.keys())
kwargs["exclude"] = exclude
return kwargs
def model_dump(self, **kwargs):
"""Override to exclude extra fields from serialization."""
kwargs = self._filter_fields(**kwargs)
return super().model_dump(**kwargs)
def model_dump_json(self, **kwargs):
"""Override to exclude extra fields from JSON serialization."""
kwargs = self._filter_fields(**kwargs)
return super().model_dump_json(**kwargs)
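To make the serialization override concrete, here is a stripped-down reproduction of the same idea with a hypothetical Demo model (not the project's class): runtime-injected extras such as model_store are excluded before dumping, so only declared fields end up in the serialized output.

from pydantic import BaseModel, ConfigDict

class Demo(BaseModel):
    model_config = ConfigDict(extra="allow")
    config: dict = {}

    def model_dump(self, **kwargs):
        # same idea as _filter_fields above: exclude anything stored as an extra
        extra = getattr(self, "__pydantic_extra__", None) or {}
        exclude = set(kwargs.get("exclude") or set()) | set(extra)
        kwargs["exclude"] = exclude
        return super().model_dump(**kwargs)

d = Demo(config={"url": "http://localhost:11434"})
d.model_store = object()  # injected at runtime, not a pydantic field
print(d.model_dump())     # {'config': {'url': 'http://localhost:11434'}} - model_store is excluded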


@ -0,0 +1,710 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "What is the boiling point of the liquid polyjuice in celsius?"
},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_6ah4hyex",
"type": "function",
"function": {
"name": "get_boiling_point",
"arguments": "{\"liquid\":\"polyjuice\",\"unit\":\"celsius\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_6ah4hyex",
"content": "Error when running tool: get_boiling_point() got an unexpected keyword argument 'liquid'"
}
],
"max_tokens": 512,
"stream": true,
"temperature": 0.0001,
"tool_choice": "auto",
"tools": [
{
"type": "function",
"function": {
"name": "get_boiling_point",
"description": "Returns the boiling point of a liquid in Celcius or Fahrenheit."
}
}
],
"top_p": 0.9
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": "I",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514971,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " was",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514971,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " unable",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514971,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " to",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514971,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " find",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " the",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " boiling",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " point",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " of",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " poly",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": "ju",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": "ice",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " in",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " my",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " search",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": ".",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " Can",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " I",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " help",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " you",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " with",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " something",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": " else",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": "?",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-622",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1759514972,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}

View file

@ -0,0 +1,710 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "What is the boiling point of the liquid polyjuice in celsius?"
},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_4gduxvhb",
"type": "function",
"function": {
"name": "get_boiling_point",
"arguments": "{\"liquid\":\"polyjuice\",\"unit\":\"celsius\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_4gduxvhb",
"content": "Error when running tool: get_boiling_point() got an unexpected keyword argument 'liquid'"
}
],
"max_tokens": 512,
"stream": true,
"temperature": 0.0001,
"tool_choice": "required",
"tools": [
{
"type": "function",
"function": {
"name": "get_boiling_point",
"description": "Returns the boiling point of a liquid in Celcius or Fahrenheit."
}
}
],
"top_p": 0.9
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": "I",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " was",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " unable",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " to",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " find",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " the",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " boiling",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " point",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " of",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " poly",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": "ju",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": "ice",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " in",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " my",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " search",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": ".",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " Can",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " I",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " help",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " you",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " with",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " something",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": " else",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": "?",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-759",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1759514982,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}

View file

@ -0,0 +1,57 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama-guard3:1b",
"messages": [
{
"role": "user",
"content": "Task: Check if there is unsafe content in 'Assistant' messages in conversations according our safety policy with the below categories.\n\n<BEGIN UNSAFE CONTENT CATEGORIES>\nS1: Violent Crimes.\nS2: Non-Violent Crimes.\nS3: Sex Crimes.\nS4: Child Exploitation.\nS5: Defamation.\nS6: Specialized Advice.\nS7: Privacy.\nS8: Intellectual Property.\nS9: Indiscriminate Weapons.\nS10: Hate.\nS11: Self-Harm.\nS12: Sexual Content.\nS13: Elections.\nS14: Code Interpreter Abuse.\n<END UNSAFE CONTENT CATEGORIES>\n\n<BEGIN CONVERSATION>\n\nUser: What is the boiling point of the liquid polyjuice in celsius?\n\nAssistant: \n\nTool: Error when running tool: get_boiling_point() got an unexpected keyword argument 'liquid'\n\nAssistant: I was unable to find the boiling point of polyjuice in my search. Can I help you with something else?\n\n<END CONVERSATION>\n\nProvide your safety assessment for ONLY THE LAST Assistant message in the above conversation:\n - First line must read 'safe' or 'unsafe'.\n - If unsafe, a second line must include a comma-separated list of violated categories."
}
],
"stream": false,
"temperature": 0.0
},
"endpoint": "/v1/chat/completions",
"model": "llama-guard3:1b"
},
"response": {
"body": {
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
"__data__": {
"id": "chatcmpl-774",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "safe",
"refusal": null,
"role": "assistant",
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": null
}
}
],
"created": 1759514987,
"model": "llama-guard3:1b",
"object": "chat.completion",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 2,
"prompt_tokens": 447,
"total_tokens": 449,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

File diff suppressed because it is too large

View file

@ -0,0 +1,103 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant Always respond with tool calls no matter what. "
},
{
"role": "user",
"content": "Get the boiling point of polyjuice with a tool call."
}
],
"max_tokens": 512,
"stream": true,
"temperature": 0.0001,
"tool_choice": "auto",
"tools": [
{
"type": "function",
"function": {
"name": "get_boiling_point",
"description": "Returns the boiling point of a liquid in Celcius or Fahrenheit."
}
}
],
"top_p": 0.9
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-707",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"index": 0,
"id": "call_laifztfo",
"function": {
"arguments": "{}",
"name": "get_boiling_point"
},
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514973,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-707",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": "tool_calls",
"index": 0,
"logprobs": null
}
],
"created": 1759514973,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}

File diff suppressed because it is too large

View file

@ -0,0 +1,108 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "What is the boiling point of the liquid polyjuice in celsius?"
}
],
"max_tokens": 512,
"stream": true,
"temperature": 0.0001,
"tool_choice": {
"type": "function",
"function": {
"name": "get_boiling_point"
}
},
"tools": [
{
"type": "function",
"function": {
"name": "get_boiling_point",
"description": "Returns the boiling point of a liquid in Celcius or Fahrenheit."
}
}
],
"top_p": 0.9
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-269",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"index": 0,
"id": "call_m61820zt",
"function": {
"arguments": "{\"liquid\":\"polyjuice\",\"unit\":\"celsius\"}",
"name": "get_boiling_point"
},
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514985,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-269",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": "tool_calls",
"index": 0,
"logprobs": null
}
],
"created": 1759514985,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}

View file

@ -0,0 +1,170 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Say hi to the world. Use tools to do so."
},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_bhtxlmzm",
"type": "function",
"function": {
"name": "greet_everyone",
"arguments": "{\"url\":\"world\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_bhtxlmzm",
"content": [
{
"type": "text",
"text": "Hello, world!"
}
]
},
{
"role": "assistant",
"content": "I'm able to \"speak\" to you through this chat platform, hello! Would you like me to repeat anything or provide assistance with something else?"
},
{
"role": "user",
"content": "What is the boiling point of polyjuice? Use tools to answer."
}
],
"max_tokens": 0,
"stream": true,
"tool_choice": "auto",
"tools": [
{
"type": "function",
"function": {
"name": "greet_everyone",
"parameters": {
"properties": {
"url": {
"title": "Url",
"type": "string"
}
},
"required": [
"url"
],
"title": "greet_everyoneArguments",
"type": "object"
}
}
},
{
"type": "function",
"function": {
"name": "get_boiling_point",
"description": "\n Returns the boiling point of a liquid in Celsius or Fahrenheit.\n\n :param liquid_name: The name of the liquid\n :param celsius: Whether to return the boiling point in Celsius\n :return: The boiling point of the liquid in Celcius or Fahrenheit\n ",
"parameters": {
"properties": {
"liquid_name": {
"title": "Liquid Name",
"type": "string"
},
"celsius": {
"default": true,
"title": "Celsius",
"type": "boolean"
}
},
"required": [
"liquid_name"
],
"title": "get_boiling_pointArguments",
"type": "object"
}
}
}
]
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-515",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"index": 0,
"id": "call_li57r4tl",
"function": {
"arguments": "{\"celsius\":\"true\",\"liquid_name\":\"polyjuice\"}",
"name": "get_boiling_point"
},
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-515",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": "tool_calls",
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}

File diff suppressed because it is too large

View file

@ -0,0 +1,57 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama-guard3:1b",
"messages": [
{
"role": "user",
"content": "Task: Check if there is unsafe content in 'Tool' messages in conversations according our safety policy with the below categories.\n\n<BEGIN UNSAFE CONTENT CATEGORIES>\nS1: Violent Crimes.\nS2: Non-Violent Crimes.\nS3: Sex Crimes.\nS4: Child Exploitation.\nS5: Defamation.\nS6: Specialized Advice.\nS7: Privacy.\nS8: Intellectual Property.\nS9: Indiscriminate Weapons.\nS10: Hate.\nS11: Self-Harm.\nS12: Sexual Content.\nS13: Elections.\nS14: Code Interpreter Abuse.\n<END UNSAFE CONTENT CATEGORIES>\n\n<BEGIN CONVERSATION>\n\nUser: Get the boiling point of polyjuice with a tool call.\n\nAssistant: \n\nTool: Error when running tool: get_boiling_point() missing 1 required positional argument: 'liquid_name'\n\n<END CONVERSATION>\n\nProvide your safety assessment for ONLY THE LAST Tool message in the above conversation:\n - First line must read 'safe' or 'unsafe'.\n - If unsafe, a second line must include a comma-separated list of violated categories."
}
],
"stream": false,
"temperature": 0.0
},
"endpoint": "/v1/chat/completions",
"model": "llama-guard3:1b"
},
"response": {
"body": {
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
"__data__": {
"id": "chatcmpl-394",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "safe",
"refusal": null,
"role": "assistant",
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": null
}
}
],
"created": 1759514973,
"model": "llama-guard3:1b",
"object": "chat.completion",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 2,
"prompt_tokens": 422,
"total_tokens": 424,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

@ -0,0 +1,57 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama-guard3:1b",
"messages": [
{
"role": "user",
"content": "Task: Check if there is unsafe content in 'Assistant' messages in conversations according our safety policy with the below categories.\n\n<BEGIN UNSAFE CONTENT CATEGORIES>\nS1: Violent Crimes.\nS2: Non-Violent Crimes.\nS3: Sex Crimes.\nS4: Child Exploitation.\nS5: Defamation.\nS6: Specialized Advice.\nS7: Privacy.\nS8: Intellectual Property.\nS9: Indiscriminate Weapons.\nS10: Hate.\nS11: Self-Harm.\nS12: Sexual Content.\nS13: Elections.\nS14: Code Interpreter Abuse.\n<END UNSAFE CONTENT CATEGORIES>\n\n<BEGIN CONVERSATION>\n\nUser: Get the boiling point of polyjuice with a tool call.\n\nAssistant: \n\nTool: Error when running tool: get_boiling_point() missing 1 required positional argument: 'liquid_name'\n\nAssistant: I apologize for the error. It seems that the `get_boiling_point` tool requires a liquid name as an argument.\n\nTo provide the boiling point of polyjuice, I'll need to know that polyjuice is not a real substance and its boiling point cannot be found in my database. However, if you meant to ask about Polyjuice Potion from the Harry Potter series, I can tell you that it's a fictional potion.\n\nIf you could provide more context or clarify which polyjuice you are referring to, I'll do my best to assist you with your question.\n\n<END CONVERSATION>\n\nProvide your safety assessment for ONLY THE LAST Assistant message in the above conversation:\n - First line must read 'safe' or 'unsafe'.\n - If unsafe, a second line must include a comma-separated list of violated categories."
}
],
"stream": false,
"temperature": 0.0
},
"endpoint": "/v1/chat/completions",
"model": "llama-guard3:1b"
},
"response": {
"body": {
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
"__data__": {
"id": "chatcmpl-515",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "safe",
"refusal": null,
"role": "assistant",
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": null
}
}
],
"created": 1759514975,
"model": "llama-guard3:1b",
"object": "chat.completion",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 2,
"prompt_tokens": 542,
"total_tokens": 544,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

@ -0,0 +1,57 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama-guard3:1b",
"messages": [
{
"role": "user",
"content": "Task: Check if there is unsafe content in 'Tool' messages in conversations according our safety policy with the below categories.\n\n<BEGIN UNSAFE CONTENT CATEGORIES>\nS1: Violent Crimes.\nS2: Non-Violent Crimes.\nS3: Sex Crimes.\nS4: Child Exploitation.\nS5: Defamation.\nS6: Specialized Advice.\nS7: Privacy.\nS8: Intellectual Property.\nS9: Indiscriminate Weapons.\nS10: Hate.\nS11: Self-Harm.\nS12: Sexual Content.\nS13: Elections.\nS14: Code Interpreter Abuse.\n<END UNSAFE CONTENT CATEGORIES>\n\n<BEGIN CONVERSATION>\n\nUser: What is the boiling point of the liquid polyjuice in celsius?\n\nAssistant: \n\nTool: Error when running tool: get_boiling_point() got an unexpected keyword argument 'liquid'\n\n<END CONVERSATION>\n\nProvide your safety assessment for ONLY THE LAST Tool message in the above conversation:\n - First line must read 'safe' or 'unsafe'.\n - If unsafe, a second line must include a comma-separated list of violated categories."
}
],
"stream": false,
"temperature": 0.0
},
"endpoint": "/v1/chat/completions",
"model": "llama-guard3:1b"
},
"response": {
"body": {
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
"__data__": {
"id": "chatcmpl-576",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "safe",
"refusal": null,
"role": "assistant",
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": null
}
}
],
"created": 1759514986,
"model": "llama-guard3:1b",
"object": "chat.completion",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 2,
"prompt_tokens": 421,
"total_tokens": 423,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

@ -0,0 +1,715 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "What is the boiling point of the liquid polyjuice in celsius?"
},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_m61820zt",
"type": "function",
"function": {
"name": "get_boiling_point",
"arguments": "{\"liquid\":\"polyjuice\",\"unit\":\"celsius\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_m61820zt",
"content": "Error when running tool: get_boiling_point() got an unexpected keyword argument 'liquid'"
}
],
"max_tokens": 512,
"stream": true,
"temperature": 0.0001,
"tool_choice": {
"type": "function",
"function": {
"name": "get_boiling_point"
}
},
"tools": [
{
"type": "function",
"function": {
"name": "get_boiling_point",
"description": "Returns the boiling point of a liquid in Celcius or Fahrenheit."
}
}
],
"top_p": 0.9
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": "I",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " was",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " unable",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " to",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " find",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " the",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " boiling",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " point",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " of",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " poly",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": "ju",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": "ice",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " in",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " my",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " search",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": ".",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " Can",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " I",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " help",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " you",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " with",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " something",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": " else",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": "?",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-884",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1759514986,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}

View file

@ -0,0 +1,103 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "What is the boiling point of the liquid polyjuice in celsius?"
}
],
"max_tokens": 512,
"stream": true,
"temperature": 0.0001,
"tool_choice": "auto",
"tools": [
{
"type": "function",
"function": {
"name": "get_boiling_point",
"description": "Returns the boiling point of a liquid in Celcius or Fahrenheit."
}
}
],
"top_p": 0.9
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-382",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"index": 0,
"id": "call_6ah4hyex",
"function": {
"arguments": "{\"liquid\":\"polyjuice\",\"unit\":\"celsius\"}",
"name": "get_boiling_point"
},
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514971,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-382",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": "tool_calls",
"index": 0,
"logprobs": null
}
],
"created": 1759514971,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}

View file

@ -0,0 +1,103 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "What is the boiling point of the liquid polyjuice in celsius?"
}
],
"max_tokens": 512,
"stream": true,
"temperature": 0.0001,
"tool_choice": "required",
"tools": [
{
"type": "function",
"function": {
"name": "get_boiling_point",
"description": "Returns the boiling point of a liquid in Celcius or Fahrenheit."
}
}
],
"top_p": 0.9
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-421",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"index": 0,
"id": "call_4gduxvhb",
"function": {
"arguments": "{\"liquid\":\"polyjuice\",\"unit\":\"celsius\"}",
"name": "get_boiling_point"
},
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514981,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-421",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": "tool_calls",
"index": 0,
"logprobs": null
}
],
"created": 1759514981,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}

View file

@ -0,0 +1,932 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Say hi to the world. Use tools to do so."
},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_bhtxlmzm",
"type": "function",
"function": {
"name": "greet_everyone",
"arguments": "{\"url\":\"world\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_bhtxlmzm",
"content": [
{
"type": "text",
"text": "Hello, world!"
}
]
}
],
"max_tokens": 0,
"stream": true,
"tool_choice": "auto",
"tools": [
{
"type": "function",
"function": {
"name": "greet_everyone",
"parameters": {
"properties": {
"url": {
"title": "Url",
"type": "string"
}
},
"required": [
"url"
],
"title": "greet_everyoneArguments",
"type": "object"
}
}
},
{
"type": "function",
"function": {
"name": "get_boiling_point",
"description": "\n Returns the boiling point of a liquid in Celsius or Fahrenheit.\n\n :param liquid_name: The name of the liquid\n :param celsius: Whether to return the boiling point in Celsius\n :return: The boiling point of the liquid in Celcius or Fahrenheit\n ",
"parameters": {
"properties": {
"liquid_name": {
"title": "Liquid Name",
"type": "string"
},
"celsius": {
"default": true,
"title": "Celsius",
"type": "boolean"
}
},
"required": [
"liquid_name"
],
"title": "get_boiling_pointArguments",
"type": "object"
}
}
}
]
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": "I",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": "'m",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " able",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " to",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " \"",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": "s",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": "peak",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": "\"",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " to",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " you",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " through",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " this",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " chat",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " platform",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": ",",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515073,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " hello",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": "!",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " Would",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " you",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " like",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " me",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " to",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " repeat",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " anything",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " or",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " provide",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " assistance",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " with",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " something",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": " else",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": "?",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-770",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1759515074,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}
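The recording above captures a streamed chat completion in which the assistant reply arrives as per-token "content" fragments inside each chunk's "delta". A minimal sketch, assuming the {"request": ..., "response": {"body": [...]}} fixture layout shown in these recordings (the helper name is hypothetical), that reassembles the full reply by concatenating those fragments:

import json

def reassemble_streamed_content(recording_path: str) -> str:
    # Load a recorded streaming response and join the per-chunk content deltas
    # into the complete assistant message.
    with open(recording_path) as f:
        recording = json.load(f)

    parts: list[str] = []
    for chunk in recording["response"]["body"]:
        for choice in chunk["__data__"]["choices"]:
            content = (choice.get("delta") or {}).get("content")
            if content:
                parts.append(content)
    return "".join(parts)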

View file

@@ -0,0 +1,103 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "Call get_boiling_point_with_metadata tool and answer What is the boiling point of polyjuice?"
}
],
"max_tokens": 512,
"stream": true,
"temperature": 0.0001,
"tool_choice": "auto",
"tools": [
{
"type": "function",
"function": {
"name": "get_boiling_point_with_metadata",
"description": "Returns the boiling point of a liquid in Celcius or Fahrenheit"
}
}
],
"top_p": 0.9
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-178",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"index": 0,
"id": "call_9vy3xwac",
"function": {
"arguments": "{}",
"name": "get_boiling_point_with_metadata"
},
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759515075,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-178",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": "tool_calls",
"index": 0,
"logprobs": null
}
],
"created": 1759515075,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}

View file

@@ -0,0 +1,103 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "Call get_boiling_point tool and answer What is the boiling point of polyjuice?"
}
],
"max_tokens": 512,
"stream": true,
"temperature": 0.0001,
"tool_choice": "auto",
"tools": [
{
"type": "function",
"function": {
"name": "get_boiling_point",
"description": "Returns the boiling point of a liquid in Celcius or Fahrenheit."
}
}
],
"top_p": 0.9
},
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-367",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": [
{
"index": 0,
"id": "call_swism1x1",
"function": {
"arguments": "{}",
"name": "get_boiling_point"
},
"type": "function"
}
]
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1759514987,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
"__data__": {
"id": "chatcmpl-367",
"choices": [
{
"delta": {
"content": "",
"function_call": null,
"refusal": null,
"role": "assistant",
"tool_calls": null
},
"finish_reason": "tool_calls",
"index": 0,
"logprobs": null
}
],
"created": 1759514987,
"model": "llama3.2:3b-instruct-fp16",
"object": "chat.completion.chunk",
"service_tier": null,
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}
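The two tool-call recordings above follow the same pattern: the first chunk carries the requested tool call in the "tool_calls" delta (a name plus a JSON-encoded arguments string) and a later chunk closes the turn with finish_reason "tool_calls". A small sketch, under the same fixture-layout assumption as before (helper name hypothetical), that pulls the requested calls out of a recording:

import json

def extract_tool_calls(recording_path: str) -> list[dict]:
    # Collect every tool call requested across the recorded stream, decoding
    # the JSON-encoded arguments string into a dict.
    with open(recording_path) as f:
        recording = json.load(f)

    calls: list[dict] = []
    for chunk in recording["response"]["body"]:
        for choice in chunk["__data__"]["choices"]:
            for tool_call in choice["delta"].get("tool_calls") or []:
                calls.append(
                    {
                        "name": tool_call["function"]["name"],
                        "arguments": json.loads(tool_call["function"]["arguments"] or "{}"),
                    }
                )
    return calls

For the get_boiling_point recording above this would return [{"name": "get_boiling_point", "arguments": {}}].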

File diff suppressed because it is too large

View file

@@ -22,7 +22,7 @@ def test_groq_provider_openai_client_caching():
"""Ensure the Groq provider does not cache api keys across client requests"""
config = GroqConfig()
inference_adapter = GroqInferenceAdapter(config)
inference_adapter = GroqInferenceAdapter(config=config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
@@ -40,7 +40,7 @@ def test_openai_provider_openai_client_caching():
"""Ensure the OpenAI provider does not cache api keys across client requests"""
config = OpenAIConfig()
inference_adapter = OpenAIInferenceAdapter(config)
inference_adapter = OpenAIInferenceAdapter(config=config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
@@ -59,7 +59,7 @@ def test_together_provider_openai_client_caching():
"""Ensure the Together provider does not cache api keys across client requests"""
config = TogetherImplConfig()
inference_adapter = TogetherInferenceAdapter(config)
inference_adapter = TogetherInferenceAdapter(config=config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
@@ -77,7 +77,7 @@ def test_together_provider_openai_client_caching():
def test_llama_compat_provider_openai_client_caching():
"""Ensure the LlamaCompat provider does not cache api keys across client requests"""
config = LlamaCompatConfig()
inference_adapter = LlamaCompatInferenceAdapter(config)
inference_adapter = LlamaCompatInferenceAdapter(config=config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
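The change exercised by these caching tests, and repeated throughout the test updates below, is purely constructional: adapters are now built with a keyword config argument (config=config) rather than positionally. One plausible reason for that convention, assuming the adapters are pydantic-style models (an assumption; the generic classes below are hypothetical, not the project's real ones):

from pydantic import BaseModel

class ExampleConfig(BaseModel):
    url: str = "http://localhost:8000/v1"

class ExampleAdapter(BaseModel):
    # Declared as a model field, so it must be supplied by keyword.
    config: ExampleConfig

adapter = ExampleAdapter(config=ExampleConfig())  # keyword construction works
# ExampleAdapter(ExampleConfig())  # raises TypeError: BaseModel takes no positional arguments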

View file

@@ -18,7 +18,7 @@ class TestOpenAIBaseURLConfig:
def test_default_base_url_without_env_var(self):
"""Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
config = OpenAIConfig(api_key="test-key")
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == "https://api.openai.com/v1"
@@ -27,7 +27,7 @@ class TestOpenAIBaseURLConfig:
"""Test that the adapter uses a custom base URL when provided in config."""
custom_url = "https://custom.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == custom_url
@@ -39,7 +39,7 @@ class TestOpenAIBaseURLConfig:
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == "https://env.openai.com/v1"
@@ -49,7 +49,7 @@ class TestOpenAIBaseURLConfig:
"""Test that explicit config value overrides environment variable."""
custom_url = "https://config.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Config should take precedence over environment variable
@@ -60,7 +60,7 @@ class TestOpenAIBaseURLConfig:
"""Test that the OpenAI client is initialized with the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
@@ -80,7 +80,7 @@ class TestOpenAIBaseURLConfig:
"""Test that check_model_availability uses the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method
@@ -122,7 +122,7 @@ class TestOpenAIBaseURLConfig:
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method

View file

@@ -5,45 +5,21 @@
# the root directory of this source tree.
import asyncio
import json
import time
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from openai.types.chat.chat_completion_chunk import (
Choice as OpenAIChoiceChunk,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
)
from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
ToolChoice,
UserMessage,
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import (
VLLMInferenceAdapter,
_process_vllm_chat_completion_stream_response,
)
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
# These are unit test for the remote vllm provider
# implementation. This should only contain tests which are specific to
@@ -56,37 +32,15 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
# -v -s --tb=short --disable-warnings
@pytest.fixture(scope="module")
def mock_openai_models_list():
with patch("openai.resources.models.AsyncModels.list") as mock_list:
yield mock_list
@pytest.fixture(scope="function")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config)
inference_adapter = VLLMInferenceAdapter(config=config)
inference_adapter.model_store = AsyncMock()
# Mock the __provider_spec__ attribute that would normally be set by the resolver
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_type = "vllm-inference"
inference_adapter.__provider_spec__.provider_data_validator = MagicMock()
await inference_adapter.initialize()
return inference_adapter
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
async def mock_openai_models():
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
mock_openai_models_list.return_value = mock_openai_models()
foo_model = Model(identifier="foo", provider_resource_id="foo", provider_id="vllm-inference")
await vllm_inference_adapter.register_model(foo_model)
mock_openai_models_list.assert_called()
async def test_old_vllm_tool_choice(vllm_inference_adapter):
"""
Test that we set tool_choice to none when no tools are in use
@@ -115,403 +69,6 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
assert call_args.kwargs["tool_choice"] == ToolChoice.none.value
async def test_tool_call_delta_empty_tool_call_buf():
"""
Test that we don't generate extra chunks when processing a
tool call response that didn't call any tools. Previously we would
emit chunks with spurious ToolCallParseStatus.succeeded or
ToolCallParseStatus.failed when processing chunks that didn't
actually make any tool calls.
"""
async def mock_stream():
delta = OpenAIChoiceDelta(content="", tool_calls=None)
choices = [OpenAIChoiceChunk(delta=delta, finish_reason="stop", index=0)]
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=choices,
)
for chunk in [mock_chunk]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "complete"
assert chunks[1].event.stop_reason == StopReason.end_of_turn
async def test_tool_call_delta_streaming_arguments_dict():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments="",
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "complete"
async def test_multiple_tool_calls():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=2,
function=OpenAIChoiceDeltaToolCallFunction(
name="multiple",
arguments='{"first_number": 4, "second_number": 7}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 4
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "progress"
assert chunks[2].event.delta.type == "tool_call"
assert chunks[2].event.delta.parse_status.value == "succeeded"
assert chunks[2].event.delta.tool_call.arguments == '{"first_number": 4, "second_number": 7}'
assert chunks[3].event.event_type.value == "complete"
async def test_process_vllm_chat_completion_stream_response_no_choices():
"""
Test that we don't error out when vLLM returns no choices for a
completion request. This can happen when there's an error thrown
in vLLM for example.
"""
async def mock_stream():
choices = []
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=choices,
)
for chunk in [mock_chunk]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 1
assert chunks[0].event.event_type.value == "start"
async def test_get_params_empty_tools(vllm_inference_adapter):
request = ChatCompletionRequest(
tools=[],
model="test_model",
messages=[UserMessage(content="test")],
)
params = await vllm_inference_adapter._get_params(request)
assert "tools" not in params
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
"""
Tests the edge case where the model returns the arguments for the tool call in the same chunk that
contains the finish reason (i.e., the last one).
We want to make sure the tool call is executed in this case, and the parameters are passed correctly.
"""
mock_tool_name = "mock_tool"
mock_tool_arguments = {"arg1": 0, "arg2": 100}
mock_tool_arguments_str = json.dumps(mock_tool_arguments)
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": None,
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {
"name": None,
"arguments": mock_tool_arguments_str,
},
}
],
},
"finish_reason": "tool_calls",
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments_str
async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
"""
Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
finish reason.
We want to make sure that this case is recognized and handled correctly, i.e., as a valid end of message.
"""
mock_tool_name = "mock_tool"
mock_tool_arguments = {"arg1": 0, "arg2": 100}
mock_tool_arguments_str = json.dumps(mock_tool_arguments)
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": mock_tool_arguments_str,
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments_str
async def test_process_vllm_chat_completion_stream_response_tool_without_args():
"""
Tests the edge case where no arguments are provided for the tool call.
Tool calls with no arguments should be treated as regular tool calls, which was not the case until now.
"""
mock_tool_name = "mock_tool"
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": "",
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == "{}"
async def test_health_status_success(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when the connection is successful.
@@ -642,94 +199,30 @@ async def test_should_refresh_models():
# Test case 1: refresh_models is True, api_token is None
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
adapter1 = VLLMInferenceAdapter(config1)
adapter1 = VLLMInferenceAdapter(config=config1)
result1 = await adapter1.should_refresh_models()
assert result1 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 2: refresh_models is True, api_token is empty string
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
adapter2 = VLLMInferenceAdapter(config2)
adapter2 = VLLMInferenceAdapter(config=config2)
result2 = await adapter2.should_refresh_models()
assert result2 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 3: refresh_models is True, api_token is "fake" (default)
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
adapter3 = VLLMInferenceAdapter(config3)
adapter3 = VLLMInferenceAdapter(config=config3)
result3 = await adapter3.should_refresh_models()
assert result3 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 4: refresh_models is True, api_token is real token
config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
adapter4 = VLLMInferenceAdapter(config4)
adapter4 = VLLMInferenceAdapter(config=config4)
result4 = await adapter4.should_refresh_models()
assert result4 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 5: refresh_models is False, api_token is real token
config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
adapter5 = VLLMInferenceAdapter(config5)
adapter5 = VLLMInferenceAdapter(config=config5)
result5 = await adapter5.should_refresh_models()
assert result5 is False, "should_refresh_models should return False when refresh_models is False"
async def test_provider_data_var_context_propagation(vllm_inference_adapter):
"""
Test that PROVIDER_DATA_VAR context is properly propagated through the vLLM inference adapter.
This ensures that dynamic provider data (like API tokens) can be passed through context.
Note: The base URL is always taken from config.url, not from provider data.
"""
# Mock the AsyncOpenAI class to capture provider data
with (
patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class,
patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data,
):
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock()
mock_openai_class.return_value = mock_client
# Mock provider data to return test data
mock_provider_data = MagicMock()
mock_provider_data.vllm_api_token = "test-token-123"
mock_provider_data.vllm_url = "http://test-server:8000/v1"
mock_get_provider_data.return_value = mock_provider_data
# Mock the model
mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
try:
# Execute chat completion
await vllm_inference_adapter.openai_chat_completion(
model="test-model",
messages=[UserMessage(content="Hello")],
stream=False,
)
# Verify that ALL client calls were made with the correct parameters
calls = mock_openai_class.call_args_list
incorrect_calls = []
for i, call in enumerate(calls):
api_key = call[1]["api_key"]
base_url = call[1]["base_url"]
if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345":
incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url})
if incorrect_calls:
error_msg = (
f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n"
)
for incorrect_call in incorrect_calls:
error_msg += f" Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n"
error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'"
raise AssertionError(error_msg)
# Ensure at least one call was made
assert len(calls) >= 1, "No AsyncOpenAI client calls were made"
# Verify that chat completion was called
mock_client.chat.completions.create.assert_called_once()
finally:
# Clean up context
pass
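The provider-data propagation test above pins down the contract: the per-request token supplied through provider data becomes the client's api_key, while base_url always comes from config.url. As a rough illustration of how a caller might scope such per-request data, the header name and the exact request_provider_data_context usage (the helper imported in the next file's diff below) are assumptions, not verified here:

import json

from llama_stack.core.request_headers import request_provider_data_context

# Hypothetical usage: scope a per-request vLLM token to the calls made inside the block.
headers = {"x-llamastack-provider-data": json.dumps({"vllm_api_token": "test-token-123"})}

with request_provider_data_context(headers):
    # Adapter calls made here would see vllm_api_token via get_request_provider_data(),
    # while the base URL still comes from config.url.
    pass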

View file

@@ -13,6 +13,7 @@ from pydantic import BaseModel, Field
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -29,7 +30,7 @@ class OpenAIMixinImpl(OpenAIMixin):
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
"""Test implementation with embedding model metadata"""
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
}
@@ -38,7 +39,8 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin with mocked model_store"""
mixin_instance = OpenAIMixinImpl()
config = RemoteInferenceProviderConfig()
mixin_instance = OpenAIMixinImpl(config=config)
# just enough to satisfy _get_provider_model_id calls
mock_model_store = MagicMock()
@@ -53,7 +55,8 @@ def mixin():
@pytest.fixture
def mixin_with_embeddings():
"""Create a test instance of OpenAIMixin with embedding model metadata"""
return OpenAIMixinWithEmbeddingsImpl()
config = RemoteInferenceProviderConfig()
return OpenAIMixinWithEmbeddingsImpl(config=config)
@pytest.fixture
@@ -504,7 +507,8 @@ class TestOpenAIMixinProviderDataApiKey:
@pytest.fixture
def mixin_with_provider_data_field(self):
"""Mixin instance with provider_data_api_key_field set"""
mixin_instance = OpenAIMixinWithProviderData()
config = RemoteInferenceProviderConfig()
mixin_instance = OpenAIMixinWithProviderData(config=config)
# Mock provider_spec for provider data validation
mock_provider_spec = MagicMock()