Fix precommit check after moving to ruff (#927)

The lint check on the main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file, as well as fix some additional checks
and ignore others.
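
For illustration, a minimal sketch of the kind of `ruff.toml` this change implies. The 120-character line length is inferred from the reformatted lines in this diff; the rule selection and ignore list below are assumptions, not the exact contents of the PR:

    # Hypothetical ruff.toml sketch, not the exact file from this PR.
    # line-length = 120 is inferred from the reformatted lines in this diff.
    line-length = 120

    [lint]
    # Example selection: pycodestyle errors, pyflakes, isort.
    # The PR's real select/ignore lists may differ.
    select = ["E", "F", "I"]
    ignore = ["E501"]  # leave line length to the formatter

The pre-commit hook would then typically run ruff via the astral-sh/ruff-pre-commit hooks (`ruff` for linting, `ruff-format` for formatting); whether this PR wires it up exactly that way is an assumption.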

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Yuan Tang, 2025-02-02 09:46:45 -05:00, committed by GitHub
parent 4773092dd1
commit 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions

@@ -114,13 +114,9 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
new_dataset = hf_datasets.Dataset.from_list(rows)
# Concatenate the new rows with existing dataset
updated_dataset = hf_datasets.concatenate_datasets(
[loaded_dataset, new_dataset]
)
updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])
if dataset_def.metadata.get("path", None):
updated_dataset.push_to_hub(dataset_def.metadata["path"])
else:
raise NotImplementedError(
"Uploading to URL-based datasets is not supported yet"
)
raise NotImplementedError("Uploading to URL-based datasets is not supported yet")

@@ -102,9 +102,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
@@ -123,9 +121,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params_for_chat_completion(request)
res = self.client.invoke_model(**params)
chunk = next(res["body"])
@@ -139,9 +135,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
response = OpenAICompatCompletionResponse(choices=[choice])
return process_chat_completion_response(response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_chat_completion(request)
res = self.client.invoke_model_with_response_stream(**params)
event_stream = res["body"]
@@ -157,14 +151,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
yield OpenAICompatCompletionResponse(choices=[choice])
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
async def _get_params_for_chat_completion(
self, request: ChatCompletionRequest
) -> Dict:
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
bedrock_model = request.model
sampling_params = request.sampling_params
@@ -175,9 +165,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
if sampling_params.repetition_penalty > 0:
options["repetition_penalty"] = sampling_params.repetition_penalty
prompt = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
)
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter)
return {
"modelId": bedrock_model,
"body": json.dumps(
@@ -196,9 +184,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
model = await self.model_store.get_model(model_id)
embeddings = []
for content in contents:
assert not content_has_media(
content
), "Bedrock does not support media for embeddings"
assert not content_has_media(content), "Bedrock does not support media for embeddings"
input_text = interleaved_content_as_str(content)
input_body = {"inputText": input_text}
body = json.dumps(input_body)

@@ -10,9 +10,7 @@ from .config import CerebrasImplConfig
async def get_adapter_impl(config: CerebrasImplConfig, _deps):
from .cerebras import CerebrasInferenceAdapter
assert isinstance(
config, CerebrasImplConfig
), f"Unexpected config type: {type(config)}"
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
impl = CerebrasInferenceAdapter(config)

@@ -102,9 +102,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
else:
return await self._nonstream_completion(request)
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.client.completions.create(**params)
@@ -149,33 +147,23 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: CompletionRequest
) -> CompletionResponse:
async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.client.completions.create(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: CompletionRequest
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
if request.sampling_params and isinstance(
request.sampling_params.strategy, TopKSamplingStrategy
):
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
raise ValueError("`top_k` not supported by Cerebras")
prompt = ""

@@ -9,9 +9,7 @@ from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
assert isinstance(
config, DatabricksImplConfig
), f"Unexpected config type: {type(config)}"
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
await impl.initialize()
return impl

@@ -114,9 +114,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
r = client.completions.create(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():
@@ -125,17 +123,13 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": request.model,
"prompt": chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
),
"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}

@@ -16,9 +16,7 @@ class FireworksProviderDataValidator(BaseModel):
async def get_adapter_impl(config: FireworksImplConfig, _deps):
from .fireworks import FireworksInferenceAdapter
assert isinstance(
config, FireworksImplConfig
), f"Unexpected config type: {type(config)}"
assert isinstance(config, FireworksImplConfig), f"Unexpected config type: {type(config)}"
impl = FireworksInferenceAdapter(config)
await impl.initialize()
return impl

@@ -95,9 +95,7 @@ MODEL_ALIASES = [
]
class FireworksInferenceAdapter(
ModelRegistryHelper, Inference, NeedsRequestProviderData
):
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
self.config = config
@@ -147,9 +145,7 @@ class FireworksInferenceAdapter(
else:
return await self._nonstream_completion(request)
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self._get_client().completion.acreate(**params)
return process_completion_response(r, self.formatter)
@@ -227,9 +223,7 @@ class FireworksInferenceAdapter(
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await self._get_client().chat.completions.acreate(**params)
@@ -237,9 +231,7 @@ class FireworksInferenceAdapter(
r = await self._get_client().completion.acreate(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
async def _to_async_generator():
@@ -251,34 +243,25 @@ class FireworksInferenceAdapter(
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
input_dict = {}
media_present = request_has_media(request)
if isinstance(request, ChatCompletionRequest):
if media_present:
input_dict["messages"] = [
await convert_message_to_openai_dict(m, download=True)
for m in request.messages
await convert_message_to_openai_dict(m, download=True) for m in request.messages
]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
)
else:
assert (
not media_present
), "Fireworks does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(
request, self.formatter
)
assert not media_present, "Fireworks does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
# Fireworks always prepends with BOS
if "prompt" in input_dict:
@@ -289,9 +272,7 @@ class FireworksInferenceAdapter(
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(
request.sampling_params, request.response_format, request.logprobs
),
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
async def embeddings(
@@ -304,9 +285,9 @@ class FireworksInferenceAdapter(
kwargs = {}
if model.metadata.get("embedding_dimensions"):
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
assert all(
not content_has_media(content) for content in contents
), "Fireworks does not support media for embeddings"
assert all(not content_has_media(content) for content in contents), (
"Fireworks does not support media for embeddings"
)
response = self._get_client().embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],

@@ -99,9 +99,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
model_id = self.get_provider_model_id(model_id)
if model_id == "llama-3.2-3b-preview":
warnings.warn(
@@ -129,9 +127,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
except groq.BadRequestError as e:
if e.body.get("error", {}).get("code") == "tool_use_failed":
# For smaller models, Groq may fail to call a tool even when the request is well formed
raise ValueError(
"Groq failed to call a tool", e.body.get("error", {})
) from e
raise ValueError("Groq failed to call a tool", e.body.get("error", {})) from e
else:
raise e

@@ -103,9 +103,7 @@ def _convert_message(message: Message) -> ChatCompletionMessageParam:
elif message.role == "user":
return ChatCompletionUserMessageParam(role="user", content=message.content)
elif message.role == "assistant":
return ChatCompletionAssistantMessageParam(
role="assistant", content=message.content
)
return ChatCompletionAssistantMessageParam(role="assistant", content=message.content)
else:
raise ValueError(f"Invalid message role: {message.role}")
@@ -121,10 +119,7 @@ def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
function=FunctionDefinition(
name=tool_definition.tool_name,
description=tool_definition.description,
parameters={
key: _convert_groq_tool_parameter(param)
for key, param in tool_parameters.items()
},
parameters={key: _convert_groq_tool_parameter(param) for key, param in tool_parameters.items()},
),
)
@@ -148,10 +143,7 @@ def convert_chat_completion_response(
# groq only supports n=1 at time of writing, so there is only one choice
choice = response.choices[0]
if choice.finish_reason == "tool_calls":
tool_calls = [
_convert_groq_tool_call(tool_call)
for tool_call in choice.message.tool_calls
]
tool_calls = [_convert_groq_tool_call(tool_call) for tool_call in choice.message.tool_calls]
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
# If we couldn't parse a tool call, jsonify the tool calls and return them
return ChatCompletionResponse(
@@ -221,9 +213,7 @@ async def convert_chat_completion_response_stream(
elif choice.delta.tool_calls:
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
if len(choice.delta.tool_calls) > 1:
warnings.warn(
"Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest."
)
warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
# We assume Groq produces fully formed tool calls for each chunk
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])

@@ -35,9 +35,7 @@ class NVIDIAConfig(BaseModel):
"""
url: str = Field(
default_factory=lambda: os.getenv(
"NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"
),
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
description="A base url for accessing the NVIDIA NIM",
)
api_key: Optional[SecretStr] = Field(

@@ -96,8 +96,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
if _is_nvidia_hosted(config):
if not config.api_key:
raise RuntimeError(
"API key is required for hosted NVIDIA NIM. "
"Either provide an API key or use a self-hosted NIM."
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
)
# elif self._config.api_key:
#
@@ -113,11 +112,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# make sure the client lives longer than any async calls
self._client = AsyncOpenAI(
base_url=f"{self._config.url}/v1",
api_key=(
self._config.api_key.get_secret_value()
if self._config.api_key
else "NO KEY"
),
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
@@ -150,9 +145,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
try:
response = await self._client.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(
f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
) from e
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
if stream:
return convert_openai_completion_stream(response)
@@ -178,9 +171,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
@@ -204,9 +195,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
try:
response = await self._client.chat.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(
f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
) from e
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
if stream:
return convert_openai_chat_completion_stream(response)

@@ -185,9 +185,7 @@ async def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessa
return content
elif isinstance(content, ImageContentItem):
return OpenAIChatCompletionContentPartImageParam(
image_url=OpenAIImageURL(
url=await convert_image_content_to_url(content)
),
image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)),
type="image_url",
)
elif isinstance(content, List):
@@ -260,12 +258,9 @@ async def convert_chat_completion_request(
# stream -> stream
# logprobs -> logprobs
if request.response_format and not isinstance(
request.response_format, JsonSchemaResponseFormat
):
if request.response_format and not isinstance(request.response_format, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {request.response_format}. "
"Only JsonSchemaResponseFormat is supported."
f"Unsupported response format: {request.response_format}. Only JsonSchemaResponseFormat is supported."
)
nvext = {}
@@ -286,9 +281,7 @@ async def convert_chat_completion_request(
nvext.update(guided_json=request.response_format.json_schema)
if request.tools:
payload.update(
tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
)
payload.update(tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools])
if request.tool_choice:
payload.update(
tool_choice=request.tool_choice.value
@@ -410,11 +403,7 @@ def _convert_openai_logprobs(
return None
return [
TokenLogProbs(
logprobs_by_token={
logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
}
)
TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs})
for content in logprobs.content
]
@@ -452,17 +441,14 @@ def convert_openai_chat_completion_choice(
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
assert (
hasattr(choice, "message") and choice.message
), "error in server response: message not found"
assert (
hasattr(choice, "finish_reason") and choice.finish_reason
), "error in server response: finish_reason not found"
assert hasattr(choice, "message") and choice.message, "error in server response: message not found"
assert hasattr(choice, "finish_reason") and choice.finish_reason, (
"error in server response: finish_reason not found"
)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content
or "", # CompletionMessage content is not optional
content=choice.message.content or "", # CompletionMessage content is not optional
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
),
@@ -479,9 +465,7 @@ async def convert_openai_chat_completion_stream(
"""
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
def _event_type_generator() -> (
Generator[ChatCompletionResponseEventType, None, None]
):
def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]:
yield ChatCompletionResponseEventType.start
while True:
yield ChatCompletionResponseEventType.progress
@@ -532,18 +516,14 @@ async def convert_openai_chat_completion_stream(
# it is possible to have parallel tool calls in stream, but
# ChatCompletionResponseEvent only supports one per stream
if len(choice.delta.tool_calls) > 1:
warnings.warn(
"multiple tool calls found in a single delta, using the first, ignoring the rest"
)
warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest")
# NIM only produces fully formed tool calls, so we can assume success
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=ToolCallDelta(
tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[
0
],
tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
parse_status=ToolCallParseStatus.succeeded,
),
logprobs=_convert_openai_logprobs(choice.logprobs),
@@ -618,10 +598,7 @@ def convert_completion_request(
nvext.update(top_k=-1)
payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k":
if (
request.sampling_params.top_k != -1
and request.sampling_params.top_k < 1
):
if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
warnings.warn("top_k must be -1 or >= 1")
nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy":
@@ -640,9 +617,7 @@ def _convert_openai_completion_logprobs(
if not logprobs:
return None
return [
TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs
]
return [TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs]
def convert_openai_completion_choice(

@@ -16,7 +16,5 @@ class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL
@classmethod
def sample_run_config(
cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs
) -> Dict[str, Any]:
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs) -> Dict[str, Any]:
return {"url": url}

@@ -242,9 +242,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
else:
return await self._nonstream_chat_completion(request)
async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens.
@@ -255,14 +253,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
media_present = request_has_media(request)
if isinstance(request, ChatCompletionRequest):
if media_present:
contents = [
await convert_message_to_openai_dict_for_ollama(m)
for m in request.messages
]
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
# flatten the list of lists
input_dict["messages"] = [
item for sublist in contents for item in sublist
]
input_dict["messages"] = [item for sublist in contents for item in sublist]
else:
input_dict["raw"] = True
input_dict["prompt"] = await chat_completion_request_to_prompt(
@@ -271,12 +264,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
self.formatter,
)
else:
assert (
not media_present
), "Ollama does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(
request, self.formatter
)
assert not media_present, "Ollama does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
input_dict["raw"] = True
if fmt := request.response_format:
@@ -294,9 +283,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
"stream": request.stream,
}
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await self.client.chat(**params)
@@ -318,9 +305,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
)
return process_chat_completion_response(response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
@@ -344,9 +329,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
async def embeddings(
@@ -356,9 +339,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert all(
not content_has_media(content) for content in contents
), "Ollama does not support media for embeddings"
assert all(not content_has_media(content) for content in contents), (
"Ollama does not support media for embeddings"
)
response = await self.client.embed(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
@@ -395,11 +378,7 @@ async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[di
if isinstance(content, ImageContentItem):
return {
"role": message.role,
"images": [
await convert_image_content_to_url(
content, download=True, include_format=False
)
],
"images": [await convert_image_content_to_url(content, download=True, include_format=False)],
}
else:
text = content.text if isinstance(content, TextContentItem) else content

@@ -9,9 +9,7 @@ from .runpod import RunpodInferenceAdapter
async def get_adapter_impl(config: RunpodImplConfig, _deps):
assert isinstance(
config, RunpodImplConfig
), f"Unexpected config type: {type(config)}"
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
impl = RunpodInferenceAdapter(config)
await impl.initialize()
return impl

@@ -45,9 +45,7 @@ RUNPOD_SUPPORTED_MODELS = {
class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS
)
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -104,9 +102,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
r = client.completions.create(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():
@@ -115,9 +111,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:

@@ -15,9 +15,7 @@ class SambaNovaProviderDataValidator(BaseModel):
async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
assert isinstance(
config, SambaNovaImplConfig
), f"Unexpected config type: {type(config)}"
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
impl = SambaNovaInferenceAdapter(config)
await impl.initialize()
return impl

@@ -137,9 +137,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
else:
return await self._nonstream_chat_completion(request_sambanova)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
response = self._get_client().chat.completions.create(**request)
choice = response.choices[0]
@@ -147,30 +145,22 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
result = ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content or "",
stop_reason=self.convert_to_sambanova_finish_reason(
choice.finish_reason
),
tool_calls=self.convert_to_sambanova_tool_calls(
choice.message.tool_calls
),
stop_reason=self.convert_to_sambanova_finish_reason(choice.finish_reason),
tool_calls=self.convert_to_sambanova_tool_calls(choice.message.tool_calls),
),
logprobs=None,
)
return result
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async def _to_async_generator():
streaming = self._get_client().chat.completions.create(**request)
for chunk in streaming:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
async def embeddings(
@@ -180,14 +170,10 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
) -> EmbeddingsResponse:
raise NotImplementedError()
async def convert_chat_completion_request(
self, request: ChatCompletionRequest
) -> dict:
async def convert_chat_completion_request(self, request: ChatCompletionRequest) -> dict:
compatible_request = self.convert_sampling_params(request.sampling_params)
compatible_request["model"] = request.model
compatible_request["messages"] = await self.convert_to_sambanova_messages(
request.messages
)
compatible_request["messages"] = await self.convert_to_sambanova_messages(request.messages)
compatible_request["stream"] = request.stream
compatible_request["logprobs"] = False
compatible_request["extra_headers"] = {
@@ -196,9 +182,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
compatible_request["tools"] = self.convert_to_sambanova_tool(request.tools)
return compatible_request
def convert_sampling_params(
self, sampling_params: SamplingParams, legacy: bool = False
) -> dict:
def convert_sampling_params(self, sampling_params: SamplingParams, legacy: bool = False) -> dict:
params = {}
if sampling_params:
@@ -219,9 +203,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
return params
async def convert_to_sambanova_messages(
self, messages: List[Message]
) -> List[dict]:
async def convert_to_sambanova_messages(self, messages: List[Message]) -> List[dict]:
conversation = []
for message in messages:
content = {}

@@ -74,9 +74,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
self.formatter = ChatFormat(Tokenizer.get_instance())
self.register_helper = ModelRegistryHelper(build_model_aliases())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor()
for model in all_registered_models()
if model.huggingface_repo
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
}
async def shutdown(self) -> None:
@@ -150,17 +148,13 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
return options
async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
prompt, input_tokens = await completion_request_to_prompt_model_input_info(
request, self.formatter
)
prompt, input_tokens = await completion_request_to_prompt_model_input_info(request, self.formatter)
return dict(
prompt=prompt,
stream=request.stream,
details=True,
max_new_tokens=self._get_max_new_tokens(
request.sampling_params, input_tokens
),
max_new_tokens=self._get_max_new_tokens(request.sampling_params, input_tokens),
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**self._build_options(request.sampling_params, request.response_format),
)
@@ -176,9 +170,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
if chunk.details:
finish_reason = chunk.details.finish_reason
choice = OpenAICompatCompletionChoice(
text=token_result.text, finish_reason=finish_reason
)
choice = OpenAICompatCompletionChoice(text=token_result.text, finish_reason=finish_reason)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
@@ -232,9 +224,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = await self.client.text_generation(**params)
@@ -247,9 +237,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
)
return process_chat_completion_response(response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
@@ -263,9 +251,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
async def _get_params(self, request: ChatCompletionRequest) -> dict:
@@ -276,9 +262,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
prompt=prompt,
stream=request.stream,
details=True,
max_new_tokens=self._get_max_new_tokens(
request.sampling_params, input_tokens
),
max_new_tokens=self._get_max_new_tokens(request.sampling_params, input_tokens),
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**self._build_options(request.sampling_params, request.response_format),
)
@@ -304,9 +288,7 @@ class TGIAdapter(_HfAdapter):
class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.client = AsyncInferenceClient(
model=config.huggingface_repo, token=config.api_token.get_secret_value()
)
self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
@@ -324,6 +306,4 @@ class InferenceEndpointAdapter(_HfAdapter):
# Initialize the adapter
self.client = endpoint.async_client
self.model_id = endpoint.repository
self.max_tokens = int(
endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
)
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])

@@ -16,9 +16,7 @@ class TogetherProviderDataValidator(BaseModel):
async def get_adapter_impl(config: TogetherImplConfig, _deps):
from .together import TogetherInferenceAdapter
assert isinstance(
config, TogetherImplConfig
), f"Unexpected config type: {type(config)}"
assert isinstance(config, TogetherImplConfig), f"Unexpected config type: {type(config)}"
impl = TogetherInferenceAdapter(config)
await impl.initialize()
return impl

@@ -90,9 +90,7 @@ MODEL_ALIASES = [
]
class TogetherInferenceAdapter(
ModelRegistryHelper, Inference, NeedsRequestProviderData
):
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
self.config = config
@@ -140,9 +138,7 @@ class TogetherInferenceAdapter(
together_api_key = provider_data.together_api_key
return Together(api_key=together_api_key)
async def _nonstream_completion(
self, request: CompletionRequest
) -> ChatCompletionResponse:
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client().completions.create(**params)
return process_completion_response(r, self.formatter)
@@ -217,9 +213,7 @@ class TogetherInferenceAdapter(
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = self._get_client().chat.completions.create(**params)
@@ -227,9 +221,7 @@ class TogetherInferenceAdapter(
r = self._get_client().completions.create(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
# if we shift to TogetherAsyncClient, we won't need this wrapper
@@ -242,40 +234,28 @@ class TogetherInferenceAdapter(
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
input_dict = {}
media_present = request_has_media(request)
if isinstance(request, ChatCompletionRequest):
if media_present:
input_dict["messages"] = [
await convert_message_to_openai_dict(m) for m in request.messages
]
input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
)
else:
assert (
not media_present
), "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(
request, self.formatter
)
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
return {
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(
request.sampling_params, request.logprobs, request.response_format
),
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
}
async def embeddings(
@@ -284,9 +264,9 @@ class TogetherInferenceAdapter(
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert all(
not content_has_media(content) for content in contents
), "Together does not support media for embeddings"
assert all(not content_has_media(content) for content in contents), (
"Together does not support media for embeddings"
)
r = self._get_client().embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],

@@ -10,9 +10,7 @@ from .config import VLLMInferenceAdapterConfig
async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
from .vllm import VLLMInferenceAdapter
assert isinstance(
config, VLLMInferenceAdapterConfig
), f"Unexpected config type: {type(config)}"
assert isinstance(config, VLLMInferenceAdapterConfig), f"Unexpected config type: {type(config)}"
impl = VLLMInferenceAdapter(config)
await impl.initialize()
return impl

@@ -147,9 +147,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
r = client.completions.create(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = await self._get_params(request)
# TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
@@ -163,14 +161,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
async for chunk in process_chat_completion_stream_response(stream, self.formatter):
yield chunk
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = self.client.completions.create(**params)
return process_completion_response(r, self.formatter)
@@ -199,9 +193,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
)
return model
async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
options = get_sampling_options(request.sampling_params)
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
@@ -211,8 +203,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if isinstance(request, ChatCompletionRequest):
if media_present:
input_dict["messages"] = [
await convert_message_to_openai_dict(m, download=True)
for m in request.messages
await convert_message_to_openai_dict(m, download=True) for m in request.messages
]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
@@ -221,9 +212,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.formatter,
)
else:
assert (
not media_present
), "vLLM does not support media for Completion requests"
assert not media_present, "vLLM does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(
request,
self.formatter,
@@ -231,9 +220,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
input_dict["extra_body"] = {
"guided_json": request.response_format.json_schema
}
input_dict["extra_body"] = {"guided_json": request.response_format.json_schema}
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
@@ -257,9 +244,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
assert model.model_type == ModelType.embedding
assert model.metadata.get("embedding_dimensions")
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
assert all(
not content_has_media(content) for content in contents
), "VLLM does not support media for embeddings"
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
response = self.client.embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],

@@ -83,9 +83,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = self.bedrock_runtime_client.apply_guardrail(
guardrailIdentifier=shield.provider_resource_id,

@@ -23,9 +23,7 @@ from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import BingSearchToolConfig
class BingSearchToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: BingSearchToolConfig):
self.config = config
self.url = "https://api.bing.microsoft.com/v7.0/search"
@@ -67,9 +65,7 @@ class BingSearchToolRuntimeImpl(
)
]
async def invoke_tool(
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key()
headers = {
"Ocp-Apim-Subscription-Key": api_key,
@@ -88,9 +84,7 @@ class BingSearchToolRuntimeImpl(
)
response.raise_for_status()
return ToolInvocationResult(
content=json.dumps(self._clean_response(response.json()))
)
return ToolInvocationResult(content=json.dumps(self._clean_response(response.json())))
def _clean_response(self, search_response):
clean_response = []
@@ -99,9 +93,7 @@ class BingSearchToolRuntimeImpl(
pages = search_response["webPages"]["value"]
for p in pages:
selected_keys = {"name", "url", "snippet"}
clean_response.append(
{k: v for k, v in p.items() if k in selected_keys}
)
clean_response.append({k: v for k, v in p.items() if k in selected_keys})
if "news" in search_response:
clean_news = []
news = search_response["news"]["value"]

@@ -23,9 +23,7 @@ from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import BraveSearchToolConfig
class BraveSearchToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: BraveSearchToolConfig):
self.config = config
@@ -67,9 +65,7 @@ class BraveSearchToolRuntimeImpl(
)
]
async def invoke_tool(
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key()
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
@@ -135,10 +131,7 @@ class BraveSearchToolRuntimeImpl(
results = result_selector(results)
if isinstance(results, list):
cleaned = [
{k: v for k, v in item.items() if k in selected_keys}
for item in results
]
cleaned = [{k: v for k, v in item.items() if k in selected_keys} for item in results]
else:
cleaned = {k: v for k, v in results.items() if k in selected_keys}

@@ -42,9 +42,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get(
"properties", {}
).items():
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
@@ -64,9 +62,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
)
return tools
async def invoke_tool(
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_name} does not have metadata")

@@ -23,9 +23,7 @@ from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import TavilySearchToolConfig
class TavilySearchToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: TavilySearchToolConfig):
self.config = config
@@ -66,18 +64,14 @@ class TavilySearchToolRuntimeImpl(
)
]
async def invoke_tool(
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key()
response = requests.post(
"https://api.tavily.com/search",
json={"api_key": api_key, "query": kwargs["query"]},
)
return ToolInvocationResult(
content=json.dumps(self._clean_tavily_response(response.json()))
)
return ToolInvocationResult(content=json.dumps(self._clean_tavily_response(response.json())))
def _clean_tavily_response(self, search_response, top_k=3):
return {"query": search_response["query"], "top_k": search_response["results"]}

@@ -23,9 +23,7 @@ from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import WolframAlphaToolConfig
class WolframAlphaToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: WolframAlphaToolConfig):
self.config = config
self.url = "https://api.wolframalpha.com/v2/query"
@@ -67,9 +65,7 @@ class WolframAlphaToolRuntimeImpl(
)
]
async def invoke_tool(
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key()
params = {
"input": kwargs["query"],
@@ -82,9 +78,7 @@ class WolframAlphaToolRuntimeImpl(
params=params,
)
return ToolInvocationResult(
content=json.dumps(self._clean_wolfram_alpha_response(response.json()))
)
return ToolInvocationResult(content=json.dumps(self._clean_wolfram_alpha_response(response.json())))
def _clean_wolfram_alpha_response(self, wa_response):
remove = {
@@ -128,10 +122,7 @@ class WolframAlphaToolRuntimeImpl(
for sub_key in key_to_remove:
if sub_key == "pods":
for i in range(len(wa_response[main_key][sub_key])):
if (
wa_response[main_key][sub_key][i]["title"]
== "Result"
):
if wa_response[main_key][sub_key][i]["title"] == "Result":
del wa_response[main_key][sub_key][i + 1 :]
break
sub_items = wa_response[main_key][sub_key]

@@ -11,9 +11,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import ChromaRemoteImplConfig
async def get_adapter_impl(
config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]
):
async def get_adapter_impl(config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]):
from .chroma import ChromaVectorIOAdapter
impl = ChromaVectorIOAdapter(config, deps[Api.inference])

@@ -42,9 +42,9 @@ class ChromaIndex(EmbeddingIndex):
self.collection = collection
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
ids = [f"{c.metadata['document_id']}:chunk-{i}" for i, c in enumerate(chunks)]
await maybe_await(
@@ -55,9 +55,7 @@ class ChromaIndex(EmbeddingIndex):
)
)
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryChunksResponse:
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
results = await maybe_await(
self.collection.query(
query_embeddings=[embedding.tolist()],
@@ -109,9 +107,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
if parsed.path and parsed.path != "/":
raise ValueError("URL should not contain a path")
self.client = await chromadb.AsyncHttpClient(
host=parsed.hostname, port=parsed.port
)
self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port)
else:
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
self.client = chromadb.PersistentClient(path=self.config.db_path)
@@ -157,9 +153,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
return await index.query_chunks(query, params)
async def _get_and_cache_vector_db_index(
self, vector_db_id: str
) -> VectorDBWithIndex:
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
@@ -169,8 +163,6 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
collection = await maybe_await(self.client.get_collection(vector_db_id))
if not collection:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
index = VectorDBWithIndex(
vector_db, ChromaIndex(self.client, collection), self.inference_api
)
index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
self.cache[vector_db_id] = index
return index

@@ -71,9 +71,9 @@ class PGVectorIndex(EmbeddingIndex):
)
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
values = []
for i, chunk in enumerate(chunks):
@@ -94,9 +94,7 @@ class PGVectorIndex(EmbeddingIndex):
)
execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryChunksResponse:
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
self.cursor.execute(
f"""
SELECT document, embedding <-> %s::vector AS distance
@@ -166,9 +164,7 @@ class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
upsert_models(self.cursor, [(vector_db.identifier, vector_db)])
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db, index, self.inference_api
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def unregister_vector_db(self, vector_db_id: str) -> None:
await self.cache[vector_db_id].index.delete()
@@ -192,15 +188,11 @@ class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
index = await self._get_and_cache_vector_db_index(vector_db_id)
return await index.query_chunks(query, params)
async def _get_and_cache_vector_db_index(
self, vector_db_id: str
) -> VectorDBWithIndex:
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
self.cache[vector_db_id] = VectorDBWithIndex(
vector_db, index, self.inference_api
)
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
return self.cache[vector_db_id]

@@ -43,16 +43,14 @@ class QdrantIndex(EmbeddingIndex):
self.collection_name = collection_name
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not await self.client.collection_exists(self.collection_name):
await self.client.create_collection(
self.collection_name,
vectors_config=models.VectorParams(
size=len(embeddings[0]), distance=models.Distance.COSINE
),
vectors_config=models.VectorParams(size=len(embeddings[0]), distance=models.Distance.COSINE),
)
points = []
@@ -62,16 +60,13 @@ class QdrantIndex(EmbeddingIndex):
PointStruct(
id=convert_id(chunk_id),
vector=embedding,
payload={"chunk_content": chunk.model_dump()}
| {CHUNK_ID_KEY: chunk_id},
payload={"chunk_content": chunk.model_dump()} | {CHUNK_ID_KEY: chunk_id},
)
)
await self.client.upsert(collection_name=self.collection_name, points=points)
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryChunksResponse:
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
results = (
await self.client.query_points(
collection_name=self.collection_name,
@@ -124,9 +119,7 @@ class QdrantVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
self.cache[vector_db.identifier] = index
async def _get_and_cache_vector_db_index(
self, vector_db_id: str
) -> Optional[VectorDBWithIndex]:
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
if vector_db_id in self.cache:
return self.cache[vector_db_id]

@@ -35,9 +35,9 @@ class WeaviateIndex(EmbeddingIndex):
self.collection_name = collection_name
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
data_objects = []
for i, chunk in enumerate(chunks):
@@ -56,9 +56,7 @@ class WeaviateIndex(EmbeddingIndex):
# TODO: make this async friendly
collection.data.insert_many(data_objects)
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryChunksResponse:
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
collection = self.client.collections.get(self.collection_name)
results = collection.query.near_vector(
@@ -85,9 +83,7 @@ class WeaviateIndex(EmbeddingIndex):
async def delete(self, chunk_ids: List[str]) -> None:
collection = self.client.collections.get(self.collection_name)
collection.data.delete_many(
where=Filter.by_property("id").contains_any(chunk_ids)
)
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
class WeaviateMemoryAdapter(
@@ -149,9 +145,7 @@ class WeaviateMemoryAdapter(
self.inference_api,
)
async def _get_and_cache_vector_db_index(
self, vector_db_id: str
) -> Optional[VectorDBWithIndex]:
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
if vector_db_id in self.cache:
return self.cache[vector_db_id]