Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-15 01:26:10 +00:00

Commit 2d5d05a2b4: feat: small ollama package
103 changed files with 7262 additions and 7422 deletions

@@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import Tool
from llama_stack.apis.tools import ToolGroup
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.schema_utils import json_schema_type

@@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol):
    async def register_benchmark(self, benchmark: Benchmark) -> None: ...


class ToolsProtocolPrivate(Protocol):
    async def register_tool(self, tool: Tool) -> None: ...
class ToolGroupsProtocolPrivate(Protocol):
    async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ...

    async def unregister_tool(self, tool_id: str) -> None: ...
    async def unregister_toolgroup(self, toolgroup_id: str) -> None: ...


@json_schema_type

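For a provider that previously implemented ToolsProtocolPrivate, the rename above amounts to swapping the base protocol and the two lifecycle hooks. A minimal sketch, assuming llama_stack is importable; the class name ExampleToolRuntime is illustrative only:

```python
from llama_stack.apis.tools import ToolGroup
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate


class ExampleToolRuntime(ToolGroupsProtocolPrivate):
    """Illustrative runtime that satisfies the renamed private protocol."""

    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
        # No per-group setup needed in this sketch.
        pass

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        return
```
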
@@ -5,6 +5,7 @@
# the root directory of this source tree.

import json
import time
import uuid
from collections.abc import AsyncIterator
from typing import Any, cast

@@ -29,10 +30,12 @@ from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseObjectStream,
    OpenAIResponseObjectStreamResponseCompleted,
    OpenAIResponseObjectStreamResponseCreated,
    OpenAIResponseObjectStreamResponseOutputTextDelta,
    OpenAIResponseOutput,
    OpenAIResponseOutputMessageContent,
    OpenAIResponseOutputMessageContentOutputText,
    OpenAIResponseOutputMessageFunctionToolCall,
    OpenAIResponseOutputMessageMCPListTools,
    OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.inference.inference import (

@@ -255,110 +258,14 @@ class OpenAIResponsesImpl:
        """
        return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)

    async def create_openai_response(
    async def _process_response_choices(
        self,
        input: str | list[OpenAIResponseInput],
        model: str,
        instructions: str | None = None,
        previous_response_id: str | None = None,
        store: bool | None = True,
        stream: bool | None = False,
        temperature: float | None = None,
        tools: list[OpenAIResponseInputTool] | None = None,
    ):
        chat_response: OpenAIChatCompletion,
        ctx: ChatCompletionContext,
        tools: list[OpenAIResponseInputTool] | None,
    ) -> list[OpenAIResponseOutput]:
        """Handle tool execution and response message creation."""
        output_messages: list[OpenAIResponseOutput] = []

        stream = False if stream is None else stream

        # Huge TODO: we need to run this in a loop, until morale improves

        # Create context to run "chat completion"
        input = await self._prepend_previous_response(input, previous_response_id)
        messages = await _convert_response_input_to_chat_messages(input)
        await self._prepend_instructions(messages, instructions)
        chat_tools, mcp_tool_to_server, mcp_list_message = (
            await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
        )
        if mcp_list_message:
            output_messages.append(mcp_list_message)

        ctx = ChatCompletionContext(
            model=model,
            messages=messages,
            tools=chat_tools,
            mcp_tool_to_server=mcp_tool_to_server,
            stream=stream,
            temperature=temperature,
        )

        # Run inference
        chat_response = await self.inference_api.openai_chat_completion(
            model=model,
            messages=messages,
            tools=chat_tools,
            stream=stream,
            temperature=temperature,
        )

        # Collect output
        if stream:
            # TODO: refactor this into a separate method that handles streaming
            chat_response_id = ""
            chat_response_content = []
            chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
            # TODO: these chunk_ fields are hacky and only take the last chunk into account
            chunk_created = 0
            chunk_model = ""
            chunk_finish_reason = ""
            async for chunk in chat_response:
                chat_response_id = chunk.id
                chunk_created = chunk.created
                chunk_model = chunk.model
                for chunk_choice in chunk.choices:
                    # TODO: this only works for text content
                    chat_response_content.append(chunk_choice.delta.content or "")
                    if chunk_choice.finish_reason:
                        chunk_finish_reason = chunk_choice.finish_reason

                    # Aggregate tool call arguments across chunks, using their index as the aggregation key
                    if chunk_choice.delta.tool_calls:
                        for tool_call in chunk_choice.delta.tool_calls:
                            response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
                            if response_tool_call:
                                response_tool_call.function.arguments += tool_call.function.arguments
                            else:
                                tool_call_dict: dict[str, Any] = tool_call.model_dump()
                                # Ensure we don't have any empty type field in the tool call dict.
                                # The OpenAI client used by providers often returns a type=None here.
                                tool_call_dict.pop("type", None)
                                response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
                            chat_response_tool_calls[tool_call.index] = response_tool_call

            # Convert the dict of tool calls by index to a list of tool calls to pass back in our response
            if chat_response_tool_calls:
                tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
            else:
                tool_calls = None
            assistant_message = OpenAIAssistantMessageParam(
                content="".join(chat_response_content),
                tool_calls=tool_calls,
            )
            chat_response = OpenAIChatCompletion(
                id=chat_response_id,
                choices=[
                    OpenAIChoice(
                        message=assistant_message,
                        finish_reason=chunk_finish_reason,
                        index=0,
                    )
                ],
                created=chunk_created,
                model=chunk_model,
            )
        else:
            # dump and reload to map to our pydantic types
            chat_response = OpenAIChatCompletion(**chat_response.model_dump())

        # Execute tool calls if any
        for choice in chat_response.choices:
            if choice.message.tool_calls and tools:

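The chunk-aggregation loop in this hunk merges partial tool calls by their `index` field as chunks arrive. A self-contained sketch of the same pattern, using an illustrative ToolCallDelta type instead of the real chunk classes:

```python
from dataclasses import dataclass


@dataclass
class ToolCallDelta:
    # Stand-in for a streamed tool-call fragment: later fragments for the same
    # index carry only additional argument text.
    index: int
    name: str | None
    arguments: str


def merge_tool_call_deltas(deltas: list[ToolCallDelta]) -> dict[int, ToolCallDelta]:
    merged: dict[int, ToolCallDelta] = {}
    for delta in deltas:
        existing = merged.get(delta.index)
        if existing is None:
            merged[delta.index] = ToolCallDelta(delta.index, delta.name, delta.arguments)
        else:
            # Same aggregation key as above: concatenate argument fragments.
            existing.arguments += delta.arguments
    return merged


chunks = [
    ToolCallDelta(0, "get_weather", '{"city": '),
    ToolCallDelta(0, None, '"Tokyo"}'),
]
assert merge_tool_call_deltas(chunks)[0].arguments == '{"city": "Tokyo"}'
```
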
@@ -380,7 +287,127 @@ class OpenAIResponsesImpl:
            else:
                output_messages.append(await _convert_chat_choice_to_response_message(choice))

        # Create response object
        return output_messages

    async def _store_response(
        self,
        response: OpenAIResponseObject,
        input: str | list[OpenAIResponseInput],
    ) -> None:
        new_input_id = f"msg_{uuid.uuid4()}"
        if isinstance(input, str):
            # synthesize a message from the input string
            input_content = OpenAIResponseInputMessageContentText(text=input)
            input_content_item = OpenAIResponseMessage(
                role="user",
                content=[input_content],
                id=new_input_id,
            )
            input_items_data = [input_content_item]
        else:
            # we already have a list of messages
            input_items_data = []
            for input_item in input:
                if isinstance(input_item, OpenAIResponseMessage):
                    # These may or may not already have an id, so dump to dict, check for id, and add if missing
                    input_item_dict = input_item.model_dump()
                    if "id" not in input_item_dict:
                        input_item_dict["id"] = new_input_id
                    input_items_data.append(OpenAIResponseMessage(**input_item_dict))
                else:
                    input_items_data.append(input_item)

        await self.responses_store.store_response_object(
            response_object=response,
            input=input_items_data,
        )

    async def create_openai_response(
        self,
        input: str | list[OpenAIResponseInput],
        model: str,
        instructions: str | None = None,
        previous_response_id: str | None = None,
        store: bool | None = True,
        stream: bool | None = False,
        temperature: float | None = None,
        tools: list[OpenAIResponseInputTool] | None = None,
    ):
        stream = False if stream is None else stream

        output_messages: list[OpenAIResponseOutput] = []

        # Input preprocessing
        input = await self._prepend_previous_response(input, previous_response_id)
        messages = await _convert_response_input_to_chat_messages(input)
        await self._prepend_instructions(messages, instructions)

        # Tool setup
        chat_tools, mcp_tool_to_server, mcp_list_message = (
            await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
        )
        if mcp_list_message:
            output_messages.append(mcp_list_message)

        ctx = ChatCompletionContext(
            model=model,
            messages=messages,
            tools=chat_tools,
            mcp_tool_to_server=mcp_tool_to_server,
            stream=stream,
            temperature=temperature,
        )

        inference_result = await self.inference_api.openai_chat_completion(
            model=model,
            messages=messages,
            tools=chat_tools,
            stream=stream,
            temperature=temperature,
        )

        if stream:
            return self._create_streaming_response(
                inference_result=inference_result,
                ctx=ctx,
                output_messages=output_messages,
                input=input,
                model=model,
                store=store,
                tools=tools,
            )
        else:
            return await self._create_non_streaming_response(
                inference_result=inference_result,
                ctx=ctx,
                output_messages=output_messages,
                input=input,
                model=model,
                store=store,
                tools=tools,
            )

    async def _create_non_streaming_response(
        self,
        inference_result: Any,
        ctx: ChatCompletionContext,
        output_messages: list[OpenAIResponseOutput],
        input: str | list[OpenAIResponseInput],
        model: str,
        store: bool | None,
        tools: list[OpenAIResponseInputTool] | None,
    ) -> OpenAIResponseObject:
        chat_response = OpenAIChatCompletion(**inference_result.model_dump())

        # Process response choices (tool execution and message creation)
        output_messages.extend(
            await self._process_response_choices(
                chat_response=chat_response,
                ctx=ctx,
                tools=tools,
            )
        )

        response = OpenAIResponseObject(
            created_at=chat_response.created,
            id=f"resp-{uuid.uuid4()}",

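A hedged usage sketch of the new dispatch: the same entry point returns either a finished response object or an async stream of events, depending on `stream`. It assumes an already-wired OpenAIResponsesImpl instance named `impl`; the model id is a placeholder.

```python
async def demo(impl) -> None:
    # stream=False: awaiting the call yields a completed OpenAIResponseObject.
    response = await impl.create_openai_response(
        input="What is the capital of France?",
        model="meta-llama/Llama-3.2-3B-Instruct",  # placeholder model id
        stream=False,
    )
    print(response.id)

    # stream=True: the call returns an async iterator of stream events instead.
    stream = await impl.create_openai_response(
        input="Same question, but streamed.",
        model="meta-llama/Llama-3.2-3B-Instruct",
        stream=True,
    )
    async for event in stream:
        print(type(event).__name__)
```
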
@@ -393,45 +420,135 @@ class OpenAIResponsesImpl:

        # Store response if requested
        if store:
            new_input_id = f"msg_{uuid.uuid4()}"
            if isinstance(input, str):
                # synthesize a message from the input string
                input_content = OpenAIResponseInputMessageContentText(text=input)
                input_content_item = OpenAIResponseMessage(
                    role="user",
                    content=[input_content],
                    id=new_input_id,
                )
                input_items_data = [input_content_item]
            else:
                # we already have a list of messages
                input_items_data = []
                for input_item in input:
                    if isinstance(input_item, OpenAIResponseMessage):
                        # These may or may not already have an id, so dump to dict, check for id, and add if missing
                        input_item_dict = input_item.model_dump()
                        if "id" not in input_item_dict:
                            input_item_dict["id"] = new_input_id
                        input_items_data.append(OpenAIResponseMessage(**input_item_dict))
                    else:
                        input_items_data.append(input_item)

            await self.responses_store.store_response_object(
                response_object=response,
                input=input_items_data,
            await self._store_response(
                response=response,
                input=input,
            )

        if stream:

            async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]:
                # TODO: response created should actually get emitted much earlier in the process
                yield OpenAIResponseObjectStreamResponseCreated(response=response)
                yield OpenAIResponseObjectStreamResponseCompleted(response=response)

            return async_response()

        return response

    async def _create_streaming_response(
        self,
        inference_result: Any,
        ctx: ChatCompletionContext,
        output_messages: list[OpenAIResponseOutput],
        input: str | list[OpenAIResponseInput],
        model: str,
        store: bool | None,
        tools: list[OpenAIResponseInputTool] | None,
    ) -> AsyncIterator[OpenAIResponseObjectStream]:
        # Create initial response and emit response.created immediately
        response_id = f"resp-{uuid.uuid4()}"
        created_at = int(time.time())

        initial_response = OpenAIResponseObject(
            created_at=created_at,
            id=response_id,
            model=model,
            object="response",
            status="in_progress",
            output=output_messages.copy(),
        )

        # Emit response.created immediately
        yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)

        # For streaming, inference_result is an async iterator of chunks
        # Stream chunks and emit delta events as they arrive
        chat_response_id = ""
        chat_response_content = []
        chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
        chunk_created = 0
        chunk_model = ""
        chunk_finish_reason = ""
        sequence_number = 0

        # Create a placeholder message item for delta events
        message_item_id = f"msg_{uuid.uuid4()}"

        async for chunk in inference_result:
            chat_response_id = chunk.id
            chunk_created = chunk.created
            chunk_model = chunk.model
            for chunk_choice in chunk.choices:
                # Emit incremental text content as delta events
                if chunk_choice.delta.content:
                    sequence_number += 1
                    yield OpenAIResponseObjectStreamResponseOutputTextDelta(
                        content_index=0,
                        delta=chunk_choice.delta.content,
                        item_id=message_item_id,
                        output_index=0,
                        sequence_number=sequence_number,
                    )

                # Collect content for final response
                chat_response_content.append(chunk_choice.delta.content or "")
                if chunk_choice.finish_reason:
                    chunk_finish_reason = chunk_choice.finish_reason

                # Aggregate tool call arguments across chunks, using their index as the aggregation key
                if chunk_choice.delta.tool_calls:
                    for tool_call in chunk_choice.delta.tool_calls:
                        response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
                        if response_tool_call:
                            response_tool_call.function.arguments += tool_call.function.arguments
                        else:
                            tool_call_dict: dict[str, Any] = tool_call.model_dump()
                            tool_call_dict.pop("type", None)
                            response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
                        chat_response_tool_calls[tool_call.index] = response_tool_call

        # Convert collected chunks to complete response
        if chat_response_tool_calls:
            tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
        else:
            tool_calls = None
        assistant_message = OpenAIAssistantMessageParam(
            content="".join(chat_response_content),
            tool_calls=tool_calls,
        )
        chat_response_obj = OpenAIChatCompletion(
            id=chat_response_id,
            choices=[
                OpenAIChoice(
                    message=assistant_message,
                    finish_reason=chunk_finish_reason,
                    index=0,
                )
            ],
            created=chunk_created,
            model=chunk_model,
        )

        # Process response choices (tool execution and message creation)
        output_messages.extend(
            await self._process_response_choices(
                chat_response=chat_response_obj,
                ctx=ctx,
                tools=tools,
            )
        )

        # Create final response
        final_response = OpenAIResponseObject(
            created_at=created_at,
            id=response_id,
            model=model,
            object="response",
            status="completed",
            output=output_messages,
        )

        if store:
            await self._store_response(
                response=final_response,
                input=input,
            )

        # Emit response.completed
        yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)

    async def _convert_response_tools_to_chat_tools(
        self, tools: list[OpenAIResponseInputTool]
    ) -> tuple[

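A hedged consumer sketch for the event sequence emitted above: one response.created event, zero or more output-text deltas, and a final response.completed. The event classes are the ones imported earlier in this diff; `stream` is whatever _create_streaming_response (or create_openai_response with stream=True) returned.

```python
from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseObjectStreamResponseCompleted,
    OpenAIResponseObjectStreamResponseCreated,
    OpenAIResponseObjectStreamResponseOutputTextDelta,
)


async def collect_text(stream) -> str:
    pieces: list[str] = []
    async for event in stream:
        if isinstance(event, OpenAIResponseObjectStreamResponseCreated):
            print("started:", event.response.id)
        elif isinstance(event, OpenAIResponseObjectStreamResponseOutputTextDelta):
            # Each delta carries an incremental slice of the assistant text.
            pieces.append(event.delta)
        elif isinstance(event, OpenAIResponseObjectStreamResponseCompleted):
            print("finished:", event.response.status)
    return "".join(pieces)
```
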
@@ -441,7 +558,6 @@ class OpenAIResponsesImpl:
    ]:
        from llama_stack.apis.agents.openai_responses import (
            MCPListToolsTool,
            OpenAIResponseOutputMessageMCPListTools,
        )
        from llama_stack.apis.tools.tools import Tool

@@ -75,7 +75,9 @@ class PromptGuardShield:
        self.temperature = temperature
        self.threshold = threshold

        self.device = "cuda"
        self.device = "cpu"
        if torch.cuda.is_available():
            self.device = "cuda"

        # load model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

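The same device fallback, written as the usual one-liner; a sketch of the check only, not a further change to the shield.

```python
import torch

# Prefer the GPU when one is visible, otherwise stay on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
```
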
@@ -25,14 +25,14 @@ from llama_stack.apis.tools import (
    RAGQueryConfig,
    RAGQueryResult,
    RAGToolRuntime,
    Tool,
    ToolDef,
    ToolGroup,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
    content_from_doc,

@@ -49,7 +49,7 @@ def make_random_string(length: int = 8):
    return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))


class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
    def __init__(
        self,
        config: RagToolRuntimeConfig,

@@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
    async def shutdown(self):
        pass

    async def register_tool(self, tool: Tool) -> None:
    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
        pass

    async def unregister_tool(self, tool_id: str) -> None:
    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        return

    async def insert(

@@ -19,10 +19,10 @@ def available_providers() -> list[ProviderSpec]:
            api=Api.agents,
            provider_type="inline::meta-reference",
            pip_packages=[
                "matplotlib",
                "pillow",
                "pandas",
                "scikit-learn",
                # "matplotlib",
                # "pillow",
                # "pandas",
                # "scikit-learn",
            ]
            + kvstore_dependencies(),
            module="llama_stack.providers.inline.agents.meta_reference",

@@ -13,7 +13,7 @@ def available_providers() -> list[ProviderSpec]:
        InlineProviderSpec(
            api=Api.eval,
            provider_type="inline::meta-reference",
            pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"],
            # pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"],
            module="llama_stack.providers.inline.eval.meta_reference",
            config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
            api_dependencies=[

@@ -20,16 +20,16 @@ def available_providers() -> list[ProviderSpec]:
            api=Api.tool_runtime,
            provider_type="inline::rag-runtime",
            pip_packages=[
                "blobfile",
                "chardet",
                "pypdf",
                "tqdm",
                "numpy",
                "scikit-learn",
                "scipy",
                "nltk",
                "sentencepiece",
                "transformers",
                # "blobfile",
                # "chardet",
                # "pypdf",
                # "tqdm",
                # "numpy",
                # "scikit-learn",
                # "scipy",
                # "nltk",
                # "sentencepiece",
                # "transformers",
            ],
            module="llama_stack.providers.inline.tool_runtime.rag",
            config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",

@@ -4,8 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator

from llama_stack.schema_utils import json_schema_type

@@ -24,11 +25,27 @@ class VLLMInferenceAdapterConfig(BaseModel):
        default="fake",
        description="The API token",
    )
    tls_verify: bool = Field(
    tls_verify: bool | str = Field(
        default=True,
        description="Whether to verify TLS certificates",
        description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
    )

    @field_validator("tls_verify")
    @classmethod
    def validate_tls_verify(cls, v):
        if isinstance(v, str):
            # Check if it's a boolean string
            if v.lower() in ("true", "false"):
                return v.lower() == "true"
            # Otherwise, treat it as a cert path
            cert_path = Path(v).expanduser().resolve()
            if not cert_path.exists():
                raise ValueError(f"TLS certificate file does not exist: {v}")
            if not cert_path.is_file():
                raise ValueError(f"TLS certificate path is not a file: {v}")
            return v
        return v

    @classmethod
    def sample_run_config(
        cls,

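A standalone sketch of the validation behaviour this hunk adds, using a stand-in model so it runs without llama_stack installed (the real class is VLLMInferenceAdapterConfig):

```python
from pathlib import Path

from pydantic import BaseModel, field_validator


class TLSConfig(BaseModel):
    tls_verify: bool | str = True

    @field_validator("tls_verify")
    @classmethod
    def validate_tls_verify(cls, v):
        if isinstance(v, str):
            # Boolean-looking strings are coerced; anything else must be an existing CA file.
            if v.lower() in ("true", "false"):
                return v.lower() == "true"
            if not Path(v).expanduser().resolve().is_file():
                raise ValueError(f"TLS certificate file does not exist: {v}")
        return v


print(TLSConfig(tls_verify="false").tls_verify)  # False (coerced to bool)
print(TLSConfig(tls_verify=True).tls_verify)     # True
# TLSConfig(tls_verify="/no/such/ca.pem") raises a validation error.
```
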
@@ -313,7 +313,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
        return AsyncOpenAI(
            base_url=self.config.url,
            api_key=self.config.api_token,
            http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
            http_client=httpx.AsyncClient(verify=self.config.tls_verify),
        )

    async def completion(

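Passing tls_verify straight through works because httpx accepts either a boolean or a path to a CA bundle for its `verify` argument, which is why the config value no longer has to be collapsed to True/False first. A small sketch:

```python
import httpx

# bool: verify against the system CA store, or disable verification entirely.
client = httpx.AsyncClient(verify=True)
insecure_client = httpx.AsyncClient(verify=False)
# str: path to a custom CA bundle, e.g. httpx.AsyncClient(verify="/path/to/ca.pem").
```
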
@@ -12,19 +12,19 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    Tool,
    ToolDef,
    ToolGroup,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate

from .config import BingSearchToolConfig


class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
    def __init__(self, config: BingSearchToolConfig):
        self.config = config
        self.url = "https://api.bing.microsoft.com/v7.0/search"

@@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
    async def initialize(self):
        pass

    async def register_tool(self, tool: Tool) -> None:
    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
        pass

    async def unregister_tool(self, tool_id: str) -> None:
    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        return

    def _get_api_key(self) -> str:

@@ -11,30 +11,30 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    Tool,
    ToolDef,
    ToolGroup,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate

from .config import BraveSearchToolConfig


class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
    def __init__(self, config: BraveSearchToolConfig):
        self.config = config

    async def initialize(self):
        pass

    async def register_tool(self, tool: Tool) -> None:
    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
        pass

    async def unregister_tool(self, tool_id: str) -> None:
    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        return

    def _get_api_key(self) -> str:

@@ -10,8 +10,8 @@ from pydantic import BaseModel


class MCPProviderDataValidator(BaseModel):
    # mcp_endpoint => list of headers to send
    mcp_headers: dict[str, list[str]] | None = None
    # mcp_endpoint => dict of headers to send
    mcp_headers: dict[str, dict[str, str]] | None = None


class MCPProviderConfig(BaseModel):

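The shape of the per-request provider data changes accordingly: headers are now given as a mapping instead of "Key: Value" strings. A hedged sketch; the endpoint URL and token are placeholders.

```python
# Before: one list of "Key: Value" strings per MCP endpoint.
old_provider_data = {
    "mcp_headers": {
        "http://localhost:8000/sse": ["Authorization: Bearer abc123"],
    }
}

# After: a plain header dict per endpoint, forwarded as-is (see MCPProviderDataValidator above).
new_provider_data = {
    "mcp_headers": {
        "http://localhost:8000/sse": {"Authorization": "Bearer abc123"},
    }
}
```
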
@@ -11,26 +11,33 @@ from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    ToolGroup,
    ToolInvocationResult,
    ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools

from .config import MCPProviderConfig

logger = get_logger(__name__, category="tools")


class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
    def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
        self.config = config

    async def initialize(self):
        pass

    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
        pass

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        return

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolDefsResponse:

@@ -62,5 +69,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
        for uri, values in provider_data.mcp_headers.items():
            if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
                continue
            headers.update(convert_header_list_to_dict(values))
            headers.update(values)
        return headers

@@ -12,29 +12,29 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    Tool,
    ToolDef,
    ToolGroup,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate

from .config import TavilySearchToolConfig


class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
    def __init__(self, config: TavilySearchToolConfig):
        self.config = config

    async def initialize(self):
        pass

    async def register_tool(self, tool: Tool) -> None:
    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
        pass

    async def unregister_tool(self, tool_id: str) -> None:
    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        return

    def _get_api_key(self) -> str:

@@ -12,19 +12,19 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    Tool,
    ToolDef,
    ToolGroup,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate

from .config import WolframAlphaToolConfig


class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
    def __init__(self, config: WolframAlphaToolConfig):
        self.config = config
        self.url = "https://api.wolframalpha.com/v2/query"

@@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
    async def initialize(self):
        pass

    async def register_tool(self, tool: Tool) -> None:
    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
        pass

    async def unregister_tool(self, tool_id: str) -> None:
    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        return

    def _get_api_key(self) -> str:

@@ -1402,9 +1402,8 @@ class OpenAIChatCompletionToLlamaStackMixin:
        outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
    ):
        id = f"chatcmpl-{uuid.uuid4()}"
        for outstanding_response in outstanding_responses:
        for i, outstanding_response in enumerate(outstanding_responses):
            response = await outstanding_response
            i = 0
            async for chunk in response:
                event = chunk.event
                finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)

@@ -1459,7 +1458,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
                    model=model,
                    object="chat.completion.chunk",
                )
                i = i + 1

    async def _process_non_stream_response(
        self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]

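Taken together, these two hunks replace the hand-maintained counter (reset to 0 for every awaited response and bumped per chunk) with the index that enumerate() assigns to each outstanding response, so `i` now identifies the choice rather than the chunk. A trivial sketch of the idiom, with placeholder data:

```python
responses = ["first", "second", "third"]  # stand-ins for awaited responses

for i, response in enumerate(responses):
    # i is stable for the whole response instead of being recomputed per chunk.
    print(i, response)
```
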
@@ -51,16 +51,6 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
        raise


def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]:
    headers = {}
    for header in header_list:
        parts = header.split(":")
        if len(parts) == 2:
            k, v = parts
            headers[k.strip()] = v.strip()
    return headers


async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
    tools = []
    async with sse_client_wrapper(endpoint, headers) as session:

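Since the helper above goes away with the switch to dict-shaped mcp_headers, anyone still holding headers in the old "Key: Value" list form can convert them once on their side. A hedged migration sketch (using partition so values containing ":" survive, which the removed split-based version did not handle):

```python
def header_list_to_dict(header_list: list[str]) -> dict[str, str]:
    headers: dict[str, str] = {}
    for header in header_list:
        key, sep, value = header.partition(":")
        if sep:
            headers[key.strip()] = value.strip()
    return headers


print(header_list_to_dict(["Authorization: Bearer abc123"]))
# {'Authorization': 'Bearer abc123'}
```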