forked from phoenix-oss/llama-stack-mirror
chore: more mypy checks (ollama, vllm, ...) (#1777)
# What does this PR do?

- **chore: mypy for strong_typing**
- **chore: mypy for remote::vllm**
- **chore: mypy for remote::ollama**
- **chore: mypy for providers.datatype**

---------

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent d5e0f32485
commit 66d6c2580e
15 changed files with 103 additions and 72 deletions
@@ -21,7 +21,7 @@ from llama_stack.schema_utils import json_schema_type


 class ModelsProtocolPrivate(Protocol):
-    async def register_model(self, model: Model) -> None: ...
+    async def register_model(self, model: Model) -> Model: ...

     async def unregister_model(self, model_id: str) -> None: ...

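For context, a minimal sketch (simplified stand-in classes, not the repo's real types) of why the protocol's return annotation matters: once `register_model` is declared to return a value, mypy checks that implementations returning the model also declare `-> Model` instead of `-> None`.

```python
from typing import Protocol


class Model:
    """Illustrative stand-in for the real Model type."""


class ModelsProtocolPrivate(Protocol):
    async def register_model(self, model: Model) -> Model: ...


class Adapter:
    # Matches the protocol: callers can rely on getting the
    # (possibly modified) model back.
    async def register_model(self, model: Model) -> Model:
        return model


impl: ModelsProtocolPrivate = Adapter()  # structural check passes under mypy
```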
@@ -113,8 +113,7 @@ Fully-qualified name of the module to import. The module is expected to have:
         default_factory=list,
         description="The pip dependencies needed for this implementation",
     )
-    config_class: Optional[str] = Field(
-        default=None,
+    config_class: str = Field(
         description="Fully-qualified classname of the config for this provider",
     )
     provider_data_validator: Optional[str] = Field(
@@ -162,7 +161,8 @@ class RemoteProviderConfig(BaseModel):
     @classmethod
     def from_url(cls, url: str) -> "RemoteProviderConfig":
         parsed = urlparse(url)
-        return cls(host=parsed.hostname, port=parsed.port, protocol=parsed.scheme)
+        attrs = {k: v for k, v in parsed._asdict().items() if v is not None}
+        return cls(**attrs)


 @json_schema_type
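The rewritten `from_url` drops `None` entries before constructing the config, so pydantic field defaults apply instead of `None` overriding them. A rough illustration of that pattern with a made-up model (not the real `RemoteProviderConfig` fields):

```python
# Hypothetical sketch: filtering out None values lets field defaults win.
from pydantic import BaseModel


class Endpoint(BaseModel):  # illustrative stand-in, not the real config class
    host: str = "localhost"
    port: int = 80
    protocol: str = "http"


raw = {"host": "example.com", "port": None, "protocol": "https"}
attrs = {k: v for k, v in raw.items() if v is not None}
print(Endpoint(**attrs))  # host='example.com' port=80 protocol='https'
```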
@@ -43,7 +43,7 @@ class SentenceTransformersInferenceImpl(
     async def shutdown(self) -> None:
         pass

-    async def register_model(self, model: Model) -> None:
+    async def register_model(self, model: Model) -> Model:
         return model

     async def unregister_model(self, model_id: str) -> None:
@@ -5,7 +5,7 @@
 # the root directory of this source tree.


-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union

 import httpx
 from ollama import AsyncClient
@@ -19,10 +19,15 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
+    ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -86,6 +91,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass

+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
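The new `_get_model` helper exists mainly to satisfy mypy: `model_store` is presumably Optional on the adapter, and the explicit guard narrows it to a concrete type before use. A self-contained sketch of the pattern with stand-in classes:

```python
# Illustrative only: guard an Optional attribute once, then use it freely.
from typing import Optional


class ModelStore:  # stand-in for the real store
    async def get_model(self, model_id: str) -> str:
        return model_id


class Adapter:
    model_store: Optional[ModelStore] = None  # populated later, e.g. at startup

    async def _get_model(self, model_id: str) -> str:
        if not self.model_store:
            raise ValueError("Model store not set")
        # After the check, mypy narrows Optional[ModelStore] to ModelStore.
        return await self.model_store.get_model(model_id)
```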
@@ -94,10 +104,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -111,7 +121,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_completion(request)

-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
         params = await self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
@@ -129,7 +141,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         async for chunk in process_completion_stream_response(stream):
             yield chunk

-    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = await self.client.generate(**params)

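The recurring signature change in these hunks replaces the bare `AsyncGenerator` annotation with either a concrete response type or a parameterized generator, and annotates the public entry point with the union of the two. A minimal sketch of that shape (toy types, not the llama-stack ones):

```python
# Hedged sketch of the typing idea: streaming returns a typed async
# generator, non-streaming returns a concrete response, and the public
# method advertises the union of both.
from typing import AsyncGenerator


class Chunk: ...
class Response: ...


async def _stream() -> AsyncGenerator[Chunk, None]:
    yield Chunk()


async def _nonstream() -> Response:
    return Response()


async def completion(stream: bool = False) -> Response | AsyncGenerator[Chunk, None]:
    if stream:
        return _stream()  # the generator object itself, not awaited
    return await _nonstream()
```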
@@ -148,17 +160,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -181,7 +193,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if sampling_options.get("max_tokens") is not None:
             sampling_options["num_predict"] = sampling_options["max_tokens"]

-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         media_present = request_has_media(request)
         llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
@@ -201,9 +213,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["raw"] = True

         if fmt := request.response_format:
-            if fmt.type == "json_schema":
+            if isinstance(fmt, JsonSchemaResponseFormat):
                 input_dict["format"] = fmt.json_schema
-            elif fmt.type == "grammar":
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format is not supported")
             else:
                 raise ValueError(f"Unknown response format type: {fmt.type}")
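Switching from string tag comparisons (`fmt.type == "json_schema"`) to `isinstance` checks lets mypy narrow the response-format union and see `.json_schema` on the narrowed type, where the tag comparison leaves it unnarrowed. A simplified sketch with stand-in dataclasses:

```python
# Illustrative sketch (simplified classes): isinstance narrows the union
# member, so the attribute access below type-checks.
from dataclasses import dataclass


@dataclass
class JsonSchemaFormat:
    json_schema: dict


@dataclass
class GrammarFormat:
    bnf: str


def to_request_format(fmt: JsonSchemaFormat | GrammarFormat) -> dict:
    if isinstance(fmt, JsonSchemaFormat):
        return {"format": fmt.json_schema}  # fmt is JsonSchemaFormat here
    raise NotImplementedError("grammar formats not supported")
```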
@@ -240,7 +252,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return process_chat_completion_response(response, request)

-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         params = await self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
@@ -275,7 +289,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)

         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
@@ -83,7 +83,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def register_model(self, model: Model) -> None:
+    async def register_model(self, model: Model) -> Model:
         model = await self.register_helper.register_model(model)
         if model.provider_resource_id != self.model_id:
             raise ValueError(
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import json
 import logging
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union

 import httpx
 from openai import AsyncOpenAI
@@ -32,11 +32,12 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
-    ResponseFormatType,
     SamplingParams,
     TextTruncation,
     ToolChoice,
@@ -102,9 +103,6 @@ def _convert_to_vllm_tool_calls_in_response(


 def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
-    if tools is None:
-        return tools
-
     compat_tools = []

     for tool in tools:
@@ -141,9 +139,7 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]

         compat_tools.append(compat_tool)

-    if len(compat_tools) > 0:
-        return compat_tools
-    return None
+    return compat_tools


 def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
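With a `List[dict]` return annotation, the old `return None` branch was a type error; returning the (possibly empty) list keeps the annotation honest and spares callers a None check. Roughly:

```python
# Sketch only: a converter annotated to return list[dict] returns [] for
# the empty case rather than None.
def convert_tools(tools: list[str]) -> list[dict]:
    compat = [{"name": name} for name in tools]
    return compat  # [] when tools is empty; never None


print(convert_tools([]))          # []
print(convert_tools(["search"]))  # [{'name': 'search'}]
```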
@@ -206,9 +202,10 @@ async def _process_vllm_chat_completion_stream_response(
             )
         elif choice.delta.tool_calls:
             tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += tool_call.tool_name
+            tool_call_buf.tool_name += str(tool_call.tool_name)
             tool_call_buf.call_id += tool_call.call_id
-            tool_call_buf.arguments += tool_call.arguments
+            # TODO: remove str() when dict type for 'arguments' is no longer allowed
+            tool_call_buf.arguments += str(tool_call.arguments)
         else:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
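The `str()` coercions around the buffered tool call are presumably there because those fields are unions (for example, a tool name can be a plain string or a builtin-tool enum), and appending a union to a `str` buffer does not type-check. A toy illustration:

```python
# Minimal sketch of the coercion: when a value is typed str | Enum,
# `buf += value` fails type checking until it is coerced to str.
from enum import Enum


class BuiltinTool(Enum):  # hypothetical stand-in enum
    code_interpreter = "code_interpreter"


buf = ""
name: str | BuiltinTool = BuiltinTool.code_interpreter
buf += str(name)  # ok for both branches; plain `buf += name` would not type-check
print(buf)
```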
@@ -240,6 +237,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass

+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
@@ -248,10 +250,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -270,17 +272,17 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         #   * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -318,11 +320,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return result

-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         params = await self._get_params(request)

         stream = await client.chat.completions.create(**params)
-        if len(request.tools) > 0:
+        if request.tools:
             res = _process_vllm_chat_completion_stream_response(stream)
         else:
             res = process_chat_completion_stream_response(stream, request)
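`if len(request.tools) > 0:` does not type-check when `tools` is Optional, since `len(None)` is invalid; the plain truthiness test covers both `None` and the empty list. In miniature:

```python
# Sketch: with an Optional list, truthiness handles both None and [].
from typing import Optional


def pick_parser(tools: Optional[list[str]]) -> str:
    if tools:  # false for None and for []
        return "tool-aware parser"
    return "plain parser"
    # `if len(tools) > 0:` would be a mypy error: tools might be None


print(pick_parser(None), pick_parser([]), pick_parser(["search"]))
```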
@@ -330,11 +334,15 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
+        assert self.client is not None
         params = await self._get_params(request)
         r = await self.client.completions.create(**params)
         return process_completion_response(r)

-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
+        assert self.client is not None
         params = await self._get_params(request)

         stream = await self.client.completions.create(**params)
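The added `assert self.client is not None` lines are the usual way to narrow an Optional attribute that is only populated during initialization: after the assert, mypy treats the attribute as non-None for the rest of the method. A condensed sketch with stand-in classes:

```python
# Illustrative only: narrowing an Optional attribute with assert.
from typing import Optional


class Client:  # stand-in for the real async client
    async def generate(self, prompt: str) -> str:
        return prompt.upper()


class Adapter:
    def __init__(self) -> None:
        self.client: Optional[Client] = None  # set later by initialize()

    async def initialize(self) -> None:
        self.client = Client()

    async def complete(self, prompt: str) -> str:
        assert self.client is not None  # narrows Optional[Client] to Client
        return await self.client.generate(prompt)
```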
@@ -342,6 +350,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk

     async def register_model(self, model: Model) -> Model:
+        assert self.client is not None
         model = await self.register_helper.register_model(model)
         res = await self.client.models.list()
         available_models = [m.id async for m in res]
@@ -357,7 +366,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens

-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         if isinstance(request, ChatCompletionRequest) and request.tools is not None:
             input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}

@@ -368,9 +377,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["prompt"] = await completion_request_to_prompt(request)

         if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                input_dict["extra_body"] = {"guided_json": request.response_format.json_schema}
-            elif fmt.type == ResponseFormatType.grammar.value:
+            if isinstance(fmt, JsonSchemaResponseFormat):
+                input_dict["extra_body"] = {"guided_json": fmt.json_schema}
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format not supported yet")
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
@@ -393,7 +402,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        assert self.client is not None
+        model = await self._get_model(model_id)

         kwargs = {}
         assert model.model_type == ModelType.embedding
@@ -104,3 +104,6 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         )

         return model
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
@@ -137,7 +137,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
     return options


-def get_sampling_options(params: SamplingParams) -> dict:
+def get_sampling_options(params: SamplingParams | None) -> dict:
+    if not params:
+        return {}
+
     options = {}
     if params:
         options.update(get_sampling_strategy_options(params))
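Widening `get_sampling_options` to accept `SamplingParams | None` pushes the None handling into the helper itself, so callers holding optional params do not need their own guard. In miniature, with a stand-in dataclass rather than the real `SamplingParams`:

```python
# Sketch of the "accept None, return an empty dict" convention.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParams:  # stand-in for the real class
    temperature: float = 0.7


def get_sampling_options(params: Optional[SamplingParams]) -> dict:
    if not params:
        return {}
    return {"temperature": params.temperature}


print(get_sampling_options(None))                 # {}
print(get_sampling_options(SamplingParams(0.2)))  # {'temperature': 0.2}
```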
@@ -297,7 +300,7 @@ def process_chat_completion_response(

 async def process_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
-) -> AsyncGenerator:
+) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
     stop_reason = None

     async for chunk in stream:
@@ -334,7 +337,7 @@ async def process_completion_stream_response(
 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
     request: ChatCompletionRequest,
-) -> AsyncGenerator:
+) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.start,