feat: small ollama package

This commit is contained in:
Raghotham Murthy 2025-05-28 21:13:48 -07:00
commit 2d5d05a2b4
103 changed files with 7262 additions and 7422 deletions

View file

@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import Tool
from llama_stack.apis.tools import ToolGroup
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.schema_utils import json_schema_type
@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol):
async def register_benchmark(self, benchmark: Benchmark) -> None: ...
class ToolsProtocolPrivate(Protocol):
async def register_tool(self, tool: Tool) -> None: ...
class ToolGroupsProtocolPrivate(Protocol):
async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ...
async def unregister_tool(self, tool_id: str) -> None: ...
async def unregister_toolgroup(self, toolgroup_id: str) -> None: ...
@json_schema_type

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import json
import time
import uuid
from collections.abc import AsyncIterator
from typing import Any, cast
@ -29,10 +30,12 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseOutput,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.inference.inference import (
@ -255,110 +258,14 @@ class OpenAIResponsesImpl:
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
async def create_openai_response(
async def _process_response_choices(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
chat_response: OpenAIChatCompletion,
ctx: ChatCompletionContext,
tools: list[OpenAIResponseInputTool] | None,
) -> list[OpenAIResponseOutput]:
"""Handle tool execution and response message creation."""
output_messages: list[OpenAIResponseOutput] = []
stream = False if stream is None else stream
# Huge TODO: we need to run this in a loop, until morale improves
# Create context to run "chat completion"
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
)
# Run inference
chat_response = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
tools=chat_tools,
stream=stream,
temperature=temperature,
)
# Collect output
if stream:
# TODO: refactor this into a separate method that handles streaming
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
# TODO: these chunk_ fields are hacky and only take the last chunk into account
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
async for chunk in chat_response:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# TODO: this only works for text content
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
# Ensure we don't have any empty type field in the tool call dict.
# The OpenAI client used by providers often returns a type=None here.
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert the dict of tool calls by index to a list of tool calls to pass back in our response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
chat_response = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
else:
# dump and reload to map to our pydantic types
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
# Execute tool calls if any
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
@ -380,7 +287,127 @@ class OpenAIResponsesImpl:
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
# Create response object
return output_messages
async def _store_response(
self,
response: OpenAIResponseObject,
input: str | list[OpenAIResponseInput],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
await self.responses_store.store_response_object(
response_object=response,
input=input_items_data,
)
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
stream = False if stream is None else stream
output_messages: list[OpenAIResponseOutput] = []
# Input preprocessing
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
# Tool setup
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
)
inference_result = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
tools=chat_tools,
stream=stream,
temperature=temperature,
)
if stream:
return self._create_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
tools=tools,
)
else:
return await self._create_non_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
tools=tools,
)
async def _create_non_streaming_response(
self,
inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
) -> OpenAIResponseObject:
chat_response = OpenAIChatCompletion(**inference_result.model_dump())
# Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response,
ctx=ctx,
tools=tools,
)
)
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
@ -393,45 +420,135 @@ class OpenAIResponsesImpl:
# Store response if requested
if store:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
await self.responses_store.store_response_object(
response_object=response,
input=input_items_data,
await self._store_response(
response=response,
input=input,
)
if stream:
async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]:
# TODO: response created should actually get emitted much earlier in the process
yield OpenAIResponseObjectStreamResponseCreated(response=response)
yield OpenAIResponseObjectStreamResponseCompleted(response=response)
return async_response()
return response
async def _create_streaming_response(
self,
inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Create initial response and emit response.created immediately
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
initial_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="in_progress",
output=output_messages.copy(),
)
# Emit response.created immediately
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# For streaming, inference_result is an async iterator of chunks
# Stream chunks and emit delta events as they arrive
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
sequence_number = 0
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
async for chunk in inference_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
sequence_number=sequence_number,
)
# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert collected chunks to complete response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
chat_response_obj = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
# Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response_obj,
ctx=ctx,
tools=tools,
)
)
# Create final response
final_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="completed",
output=output_messages,
)
if store:
await self._store_response(
response=final_response,
input=input,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool]
) -> tuple[
@ -441,7 +558,6 @@ class OpenAIResponsesImpl:
]:
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
OpenAIResponseOutputMessageMCPListTools,
)
from llama_stack.apis.tools.tools import Tool

View file

@ -75,7 +75,9 @@ class PromptGuardShield:
self.temperature = temperature
self.threshold = threshold
self.device = "cuda"
self.device = "cpu"
if torch.cuda.is_available():
self.device = "cuda"
# load model and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

View file

@ -25,14 +25,14 @@ from llama_stack.apis.tools import (
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
@ -49,7 +49,7 @@ def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__(
self,
config: RagToolRuntimeConfig,
@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def shutdown(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
async def insert(

View file

@ -19,10 +19,10 @@ def available_providers() -> list[ProviderSpec]:
api=Api.agents,
provider_type="inline::meta-reference",
pip_packages=[
"matplotlib",
"pillow",
"pandas",
"scikit-learn",
# "matplotlib",
# "pillow",
# "pandas",
# "scikit-learn",
]
+ kvstore_dependencies(),
module="llama_stack.providers.inline.agents.meta_reference",

View file

@ -13,7 +13,7 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.eval,
provider_type="inline::meta-reference",
pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"],
# pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"],
module="llama_stack.providers.inline.eval.meta_reference",
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
api_dependencies=[

View file

@ -20,16 +20,16 @@ def available_providers() -> list[ProviderSpec]:
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
pip_packages=[
"blobfile",
"chardet",
"pypdf",
"tqdm",
"numpy",
"scikit-learn",
"scipy",
"nltk",
"sentencepiece",
"transformers",
# "blobfile",
# "chardet",
# "pypdf",
# "tqdm",
# "numpy",
# "scikit-learn",
# "scipy",
# "nltk",
# "sentencepiece",
# "transformers",
],
module="llama_stack.providers.inline.tool_runtime.rag",
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",

View file

@ -4,8 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from llama_stack.schema_utils import json_schema_type
@ -24,11 +25,27 @@ class VLLMInferenceAdapterConfig(BaseModel):
default="fake",
description="The API token",
)
tls_verify: bool = Field(
tls_verify: bool | str = Field(
default=True,
description="Whether to verify TLS certificates",
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
)
@field_validator("tls_verify")
@classmethod
def validate_tls_verify(cls, v):
if isinstance(v, str):
# Check if it's a boolean string
if v.lower() in ("true", "false"):
return v.lower() == "true"
# Otherwise, treat it as a cert path
cert_path = Path(v).expanduser().resolve()
if not cert_path.exists():
raise ValueError(f"TLS certificate file does not exist: {v}")
if not cert_path.is_file():
raise ValueError(f"TLS certificate path is not a file: {v}")
return v
return v
@classmethod
def sample_run_config(
cls,

View file

@ -313,7 +313,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
return AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
http_client=httpx.AsyncClient(verify=self.config.tls_verify),
)
async def completion(

View file

@ -12,19 +12,19 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import BingSearchToolConfig
class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: BingSearchToolConfig):
self.config = config
self.url = "https://api.bing.microsoft.com/v7.0/search"
@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
async def initialize(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:

View file

@ -11,30 +11,30 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import BraveSearchToolConfig
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: BraveSearchToolConfig):
self.config = config
async def initialize(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:

View file

@ -10,8 +10,8 @@ from pydantic import BaseModel
class MCPProviderDataValidator(BaseModel):
# mcp_endpoint => list of headers to send
mcp_headers: dict[str, list[str]] | None = None
# mcp_endpoint => dict of headers to send
mcp_headers: dict[str, dict[str, str]] | None = None
class MCPProviderConfig(BaseModel):

View file

@ -11,26 +11,33 @@ from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolGroup,
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
async def initialize(self):
pass
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
@ -62,5 +69,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
headers.update(convert_header_list_to_dict(values))
headers.update(values)
return headers

View file

@ -12,29 +12,29 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import TavilySearchToolConfig
class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: TavilySearchToolConfig):
self.config = config
async def initialize(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:

View file

@ -12,19 +12,19 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import WolframAlphaToolConfig
class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: WolframAlphaToolConfig):
self.config = config
self.url = "https://api.wolframalpha.com/v2/query"
@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def initialize(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:

View file

@ -1402,9 +1402,8 @@ class OpenAIChatCompletionToLlamaStackMixin:
outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
):
id = f"chatcmpl-{uuid.uuid4()}"
for outstanding_response in outstanding_responses:
for i, outstanding_response in enumerate(outstanding_responses):
response = await outstanding_response
i = 0
async for chunk in response:
event = chunk.event
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
@ -1459,7 +1458,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
model=model,
object="chat.completion.chunk",
)
i = i + 1
async def _process_non_stream_response(
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]

View file

@ -51,16 +51,6 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
raise
def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]:
headers = {}
for header in header_list:
parts = header.split(":")
if len(parts) == 2:
k, v = parts
headers[k.strip()] = v.strip()
return headers
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = []
async with sse_client_wrapper(endpoint, headers) as session: