Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
chore: more mypy checks (ollama, vllm, ...) (#1777)
# What does this PR do?

- **chore: mypy for strong_typing**
- **chore: mypy for remote::vllm**
- **chore: mypy for remote::ollama**
- **chore: mypy for providers.datatype**

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Parent: d5e0f32485
Commit: 66d6c2580e
15 changed files with 103 additions and 72 deletions
@@ -394,7 +394,7 @@ class EmbeddingsResponse(BaseModel):
 
 
 class ModelStore(Protocol):
-    def get_model(self, identifier: str) -> Model: ...
+    async def get_model(self, identifier: str) -> Model: ...
 
 
 class TextTruncation(Enum):
@@ -431,7 +431,7 @@ class Inference(Protocol):
     - Embedding models: these models generate embeddings to be used for semantic search.
     """
 
-    model_store: ModelStore
+    model_store: ModelStore | None = None
 
     @webmethod(route="/inference/completion", method="POST")
     async def completion(
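Read together, these two hunks make `get_model` awaitable and let `model_store` default to `None` until the stack wires a store in, which is what forces the None-guards added further down in this diff. A minimal sketch of the pattern with stand-in classes (`Model` and `ModelStore` below are simplified placeholders, not the real llama-stack types):

```python
from typing import Optional, Protocol


class Model:
    """Hypothetical stand-in for the real Model resource."""

    def __init__(self, identifier: str) -> None:
        self.identifier = identifier


class ModelStore(Protocol):
    async def get_model(self, identifier: str) -> Model: ...


class SomeAdapter:
    # May stay None until the distribution wires a store in, so mypy treats
    # every use of the attribute as Optional.
    model_store: Optional[ModelStore] = None

    async def lookup(self, model_id: str) -> Model:
        # Narrow the Optional before awaiting; otherwise mypy flags the
        # attribute access on a possibly-None value.
        if self.model_store is None:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)
```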
@@ -21,7 +21,7 @@ from llama_stack.schema_utils import json_schema_type
 
 
 class ModelsProtocolPrivate(Protocol):
-    async def register_model(self, model: Model) -> None: ...
+    async def register_model(self, model: Model) -> Model: ...
 
     async def unregister_model(self, model_id: str) -> None: ...
 
@@ -113,8 +113,7 @@ Fully-qualified name of the module to import. The module is expected to have:
         default_factory=list,
         description="The pip dependencies needed for this implementation",
     )
-    config_class: Optional[str] = Field(
-        default=None,
+    config_class: str = Field(
         description="Fully-qualified classname of the config for this provider",
     )
     provider_data_validator: Optional[str] = Field(
@@ -162,7 +161,8 @@ class RemoteProviderConfig(BaseModel):
     @classmethod
     def from_url(cls, url: str) -> "RemoteProviderConfig":
         parsed = urlparse(url)
-        return cls(host=parsed.hostname, port=parsed.port, protocol=parsed.scheme)
+        attrs = {k: v for k, v in parsed._asdict().items() if v is not None}
+        return cls(**attrs)
 
 
 @json_schema_type
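The rewritten `from_url` builds a kwargs dict from the parsed URL and drops empty entries so the model's own field defaults apply; `urlparse` returns a named tuple, which is what makes `_asdict()` available. A small illustration of the same "filter out None before constructing" idea with a hypothetical config model (`RemoteConfig` below is a placeholder, not the real `RemoteProviderConfig`):

```python
from urllib.parse import urlparse

from pydantic import BaseModel


class RemoteConfig(BaseModel):
    """Hypothetical stand-in with defaults for missing URL parts."""

    host: str = "localhost"
    port: int = 80
    protocol: str = "http"


def from_url(url: str) -> RemoteConfig:
    parsed = urlparse(url)
    # hostname and port are None when the URL omits them; filtering the None
    # values out lets the field defaults kick in instead of passing None into
    # non-Optional fields (which mypy would reject).
    candidate = {"host": parsed.hostname, "port": parsed.port, "protocol": parsed.scheme}
    return RemoteConfig(**{k: v for k, v in candidate.items() if v is not None})


print(from_url("https://example.com"))  # port falls back to the default
```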
@@ -43,7 +43,7 @@ class SentenceTransformersInferenceImpl(
     async def shutdown(self) -> None:
         pass
 
-    async def register_model(self, model: Model) -> None:
+    async def register_model(self, model: Model) -> Model:
         return model
 
     async def unregister_model(self, model_id: str) -> None:
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union
 
 import httpx
 from ollama import AsyncClient
@@ -19,10 +19,15 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatCompletionResponseStreamChunk,
     CompletionRequest,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -86,6 +91,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
@@ -94,10 +104,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
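The return annotation here replaces a bare `AsyncGenerator` with an explicit union: one `CompletionResponse` when `stream=False`, an async generator of chunks when `stream=True`. A hedged sketch of how such a dual-mode signature looks and how a caller branches on it (the response and chunk classes are simplified placeholders, not the llama-stack ones):

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator


@dataclass
class CompletionResponse:  # placeholder, not the llama-stack class
    content: str


@dataclass
class CompletionChunk:  # placeholder stream chunk
    delta: str


async def _stream() -> AsyncGenerator[CompletionChunk, None]:
    for piece in ("Hel", "lo"):
        yield CompletionChunk(delta=piece)


async def completion(stream: bool = False) -> CompletionResponse | AsyncGenerator[CompletionChunk, None]:
    # Non-streaming: return one response object.
    # Streaming: hand back the generator itself (returned, not iterated here).
    if stream:
        return _stream()
    return CompletionResponse(content="Hello")


async def main() -> None:
    result = await completion(stream=True)
    if isinstance(result, CompletionResponse):
        print(result.content)
    else:
        async for chunk in result:
            print(chunk.delta, end="")
        print()


asyncio.run(main())
```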
@@ -111,7 +121,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_completion(request)
 
-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
         params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
@@ -129,7 +141,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         async for chunk in process_completion_stream_response(stream):
             yield chunk
 
-    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = await self.client.generate(**params)
 
@@ -148,17 +160,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -181,7 +193,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if sampling_options.get("max_tokens") is not None:
             sampling_options["num_predict"] = sampling_options["max_tokens"]
 
-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         media_present = request_has_media(request)
         llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
@@ -201,9 +213,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["raw"] = True
 
         if fmt := request.response_format:
-            if fmt.type == "json_schema":
+            if isinstance(fmt, JsonSchemaResponseFormat):
                 input_dict["format"] = fmt.json_schema
-            elif fmt.type == "grammar":
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format is not supported")
             else:
                 raise ValueError(f"Unknown response format type: {fmt.type}")
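Switching from `fmt.type == "json_schema"` to `isinstance(fmt, JsonSchemaResponseFormat)` lets mypy narrow `fmt` inside each branch, so accessing `fmt.json_schema` is type-safe without a cast. A reduced sketch of the same dispatch over a hypothetical tagged union (these classes are illustrative, not the llama-stack models):

```python
from dataclasses import dataclass, field
from typing import Union


@dataclass
class JsonSchemaFormat:  # illustrative stand-in
    type: str = "json_schema"
    json_schema: dict = field(default_factory=dict)


@dataclass
class GrammarFormat:  # illustrative stand-in
    type: str = "grammar"
    bnf: str = ""


ResponseFormat = Union[JsonSchemaFormat, GrammarFormat]


def to_request_options(fmt: ResponseFormat) -> dict:
    options: dict = {}
    if isinstance(fmt, JsonSchemaFormat):
        # mypy knows fmt is JsonSchemaFormat here, so .json_schema is safe.
        options["format"] = fmt.json_schema
    elif isinstance(fmt, GrammarFormat):
        raise NotImplementedError("Grammar response format is not supported")
    else:
        raise ValueError(f"Unknown response format type: {fmt.type}")
    return options


print(to_request_options(JsonSchemaFormat(json_schema={"type": "object"})))
```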
@@ -240,7 +252,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return process_chat_completion_response(response, request)
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
@@ -275,7 +289,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
 
         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
@@ -83,7 +83,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass
 
-    async def register_model(self, model: Model) -> None:
+    async def register_model(self, model: Model) -> Model:
         model = await self.register_helper.register_model(model)
         if model.provider_resource_id != self.model_id:
             raise ValueError(
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import json
 import logging
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union
 
 import httpx
 from openai import AsyncOpenAI
@@ -32,11 +32,12 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
-    ResponseFormatType,
     SamplingParams,
     TextTruncation,
     ToolChoice,
@@ -102,9 +103,6 @@ def _convert_to_vllm_tool_calls_in_response(
 
 
 def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
-    if tools is None:
-        return tools
-
     compat_tools = []
 
     for tool in tools:
@@ -141,9 +139,7 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]
 
         compat_tools.append(compat_tool)
 
-    if len(compat_tools) > 0:
-        return compat_tools
-    return None
+    return compat_tools
 
 
 def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
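The two vLLM hunks above simplify `_convert_to_vllm_tools_in_request` so it always returns a list instead of sometimes returning `None`, which keeps the declared `List[dict]` return type honest under mypy. The general shape of that cleanup, sketched with a hypothetical converter:

```python
from typing import List


def convert_tools(tool_names: List[str]) -> List[dict]:
    # Returning [] for "no tools" keeps the signature a plain List[dict];
    # an Optional return would force every caller to add a None check.
    return [{"type": "function", "function": {"name": name}} for name in tool_names]


print(convert_tools([]))          # [] rather than None
print(convert_tools(["search"]))  # [{'type': 'function', 'function': {'name': 'search'}}]
```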
@@ -206,9 +202,10 @@ async def _process_vllm_chat_completion_stream_response(
             )
         elif choice.delta.tool_calls:
             tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += tool_call.tool_name
+            tool_call_buf.tool_name += str(tool_call.tool_name)
             tool_call_buf.call_id += tool_call.call_id
-            tool_call_buf.arguments += tool_call.arguments
+            # TODO: remove str() when dict type for 'arguments' is no longer allowed
+            tool_call_buf.arguments += str(tool_call.arguments)
         else:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
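The `str(...)` coercions exist because the buffered fields are plain strings while the incoming delta's `tool_name` and `arguments` may not be (the TODO notes that `arguments` can still arrive as a dict). A hedged sketch of that buffering step with a made-up delta type, not the real vLLM/OpenAI objects:

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class ToolCallDelta:  # hypothetical shape of one streamed fragment
    tool_name: object  # may be a plain string or an enum member in practice
    call_id: str
    arguments: Union[str, dict]


@dataclass
class ToolCallBuffer:
    tool_name: str = ""
    call_id: str = ""
    arguments: str = ""

    def add(self, delta: ToolCallDelta) -> None:
        self.tool_name += str(delta.tool_name)
        self.call_id += delta.call_id
        # str() keeps the buffer a str even if a provider sends the
        # arguments as an already-parsed dict.
        self.arguments += str(delta.arguments)


buf = ToolCallBuffer()
buf.add(ToolCallDelta(tool_name="get_weather", call_id="c1", arguments='{"city": '))
buf.add(ToolCallDelta(tool_name="", call_id="", arguments='"Oslo"}'))
print(buf.arguments)  # '{"city": "Oslo"}'
```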
@@ -240,6 +237,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
@@ -248,10 +250,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -270,17 +272,17 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -318,11 +320,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return result
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         params = await self._get_params(request)
 
         stream = await client.chat.completions.create(**params)
-        if len(request.tools) > 0:
+        if request.tools:
             res = _process_vllm_chat_completion_stream_response(stream)
         else:
             res = process_chat_completion_stream_response(stream, request)
@@ -330,11 +334,15 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
+        assert self.client is not None
         params = await self._get_params(request)
         r = await self.client.completions.create(**params)
         return process_completion_response(r)
 
-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
+        assert self.client is not None
         params = await self._get_params(request)
 
         stream = await self.client.completions.create(**params)
@@ -342,6 +350,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def register_model(self, model: Model) -> Model:
+        assert self.client is not None
         model = await self.register_helper.register_model(model)
         res = await self.client.models.list()
         available_models = [m.id async for m in res]
@@ -357,7 +366,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens
 
-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         if isinstance(request, ChatCompletionRequest) and request.tools is not None:
             input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
 
@@ -368,9 +377,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["prompt"] = await completion_request_to_prompt(request)
 
         if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                input_dict["extra_body"] = {"guided_json": request.response_format.json_schema}
-            elif fmt.type == ResponseFormatType.grammar.value:
+            if isinstance(fmt, JsonSchemaResponseFormat):
+                input_dict["extra_body"] = {"guided_json": fmt.json_schema}
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format not supported yet")
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
@@ -393,7 +402,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        assert self.client is not None
+        model = await self._get_model(model_id)
 
         kwargs = {}
         assert model.model_type == ModelType.embedding
@@ -104,3 +104,6 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             )
 
         return model
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
@@ -137,7 +137,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
     return options
 
 
-def get_sampling_options(params: SamplingParams) -> dict:
+def get_sampling_options(params: SamplingParams | None) -> dict:
+    if not params:
+        return {}
+
     options = {}
     if params:
         options.update(get_sampling_strategy_options(params))
@@ -297,7 +300,7 @@ def process_chat_completion_response(
 
 async def process_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
-) -> AsyncGenerator:
+) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
     stop_reason = None
 
     async for chunk in stream:
@@ -334,7 +337,7 @@ async def process_completion_stream_response(
 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
     request: ChatCompletionRequest,
-) -> AsyncGenerator:
+) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.start,
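Several signatures in this file go from a bare `-> AsyncGenerator` to `-> AsyncGenerator[SomeChunk, None]`, which lets mypy type the `chunk` variable at every `async for` call site instead of treating it as `Any`. A small self-contained example of the difference (the chunk class is a placeholder):

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator


@dataclass
class StreamChunk:  # placeholder for a response stream chunk
    text: str


async def stream_chunks(pieces: list[str]) -> AsyncGenerator[StreamChunk, None]:
    # The [StreamChunk, None] parameters mean "yields StreamChunk, accepts
    # nothing via .asend()", so consumers get a typed chunk, not Any.
    for piece in pieces:
        yield StreamChunk(text=piece)


async def main() -> None:
    async for chunk in stream_chunks(["a", "b", "c"]):
        print(chunk.text.upper())  # mypy knows .text is a str here


asyncio.run(main())
```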
@@ -77,7 +77,9 @@ def typeannotation(
     """
 
     def wrap(cls: Type[T]) -> Type[T]:
-        cls.__repr__ = _compact_dataclass_repr
+        # mypy fails to equate bound-y functions (first argument interpreted as
+        # the bound object) with class methods, hence the `ignore` directive.
+        cls.__repr__ = _compact_dataclass_repr  # type: ignore[method-assign]
         if not dataclasses.is_dataclass(cls):
             cls = dataclasses.dataclass(  # type: ignore[call-overload]
                 cls,
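The `ignore` comments introduced throughout the strong_typing changes are scoped to a specific error code (`method-assign`, `no-any-return`, `misc`, ...) rather than being bare `# type: ignore`, so unrelated new errors on the same line still surface. A toy reproduction of the `method-assign` case being silenced above (names are illustrative):

```python
from typing import Type, TypeVar

T = TypeVar("T")


def describe(self: object) -> str:
    return f"<{type(self).__name__}>"


def add_repr(cls: Type[T]) -> Type[T]:
    # Assigning a plain function over a method trips mypy's method-assign
    # check; the scoped ignore silences only that error code.
    cls.__repr__ = describe  # type: ignore[method-assign]
    return cls


@add_repr
class Point:
    pass


print(repr(Point()))  # "<Point>"
```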
@@ -627,7 +627,8 @@ class NamedTupleDeserializer(ClassDeserializer[NamedTuple]):
         super().assign(property_parsers)
 
     def create(self, **field_values: Any) -> NamedTuple:
-        return self.class_type(**field_values)
+        # mypy fails to deduce that this class returns NamedTuples only, hence the `ignore` directive
+        return self.class_type(**field_values)  # type: ignore[no-any-return]
 
 
 class DataclassDeserializer(ClassDeserializer[T]):
@@ -48,7 +48,7 @@ class DocstringParam:
 
     name: str
     description: str
-    param_type: type = inspect.Signature.empty
+    param_type: type | str = inspect.Signature.empty
 
     def __str__(self) -> str:
         return f":param {self.name}: {self.description}"
@@ -260,7 +260,8 @@ def extend_enum(
     values: Dict[str, Any] = {}
     values.update((e.name, e.value) for e in source)
     values.update((e.name, e.value) for e in extend)
-    enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values)  # type: ignore
+    # mypy fails to determine that __name__ is always a string; hence the `ignore` directive.
+    enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values)  # type: ignore[misc]
 
     # assign the newly created type to the same module where the extending class is defined
     enum_class.__module__ = extend.__module__
@@ -327,9 +328,7 @@ def _unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
         raise TypeError("optional type must have un-subscripted type of Union")
 
     # will automatically unwrap Union[T] into T
-    return Union[
-        tuple(filter(lambda item: item is not type(None), typing.get_args(typ)))  # type: ignore
-    ]
+    return Union[tuple(filter(lambda item: item is not type(None), typing.get_args(typ)))]  # type: ignore[return-value]
 
 
 def is_type_union(typ: object) -> bool:
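The one-liner above keeps the existing behaviour: strip `NoneType` out of the `Optional[...]` union and rebuild the union from the remaining arguments; only the `# type: ignore` gains an explicit `[return-value]` code. The mechanism itself, shown standalone:

```python
import typing
from typing import Optional, Union


def unwrap_optional(typ: object) -> object:
    # Optional[T] is Union[T, None]; drop NoneType and rebuild the union
    # from whatever remains (a single T collapses back to T).
    args = tuple(a for a in typing.get_args(typ) if a is not type(None))
    return Union[args]


print(unwrap_optional(Optional[int]))               # <class 'int'>
print(unwrap_optional(Optional[Union[int, str]]))   # typing.Union[int, str]
```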
@@ -431,7 +430,7 @@ def _unwrap_generic_list(typ: Type[List[T]]) -> Type[T]:
     "Extracts the item type of a list type (e.g. returns `T` for `List[T]`)."
 
     (list_type,) = typing.get_args(typ)  # unpack single tuple element
-    return list_type
+    return list_type  # type: ignore[no-any-return]
 
 
 def is_generic_set(typ: object) -> TypeGuard[Type[set]]:
@@ -456,7 +455,7 @@ def _unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]:
     "Extracts the item type of a set type (e.g. returns `T` for `Set[T]`)."
 
     (set_type,) = typing.get_args(typ)  # unpack single tuple element
-    return set_type
+    return set_type  # type: ignore[no-any-return]
 
 
 def is_generic_dict(typ: object) -> TypeGuard[Type[dict]]:
@@ -513,7 +512,7 @@ def unwrap_annotated_type(typ: T) -> T:
 
     if is_type_annotated(typ):
         # type is Annotated[T, ...]
-        return typing.get_args(typ)[0]
+        return typing.get_args(typ)[0]  # type: ignore[no-any-return]
     else:
         # type is a regular type
         return typ
@@ -538,7 +537,7 @@ def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S])
     transformed_type = transform(inner_type)
 
     if metadata is not None:
-        return Annotated[(transformed_type, *metadata)]  # type: ignore
+        return Annotated[(transformed_type, *metadata)]  # type: ignore[return-value]
     else:
         return transformed_type
 
@@ -563,7 +562,7 @@ else:
         return typing.get_type_hints(typ)
 
 
-def get_class_properties(typ: type) -> Iterable[Tuple[str, type]]:
+def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
     "Returns all properties of a class."
 
     if is_dataclass_type(typ):
@@ -573,7 +572,7 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type]]:
         return resolved_hints.items()
 
 
-def get_class_property(typ: type, name: str) -> Optional[type]:
+def get_class_property(typ: type, name: str) -> Optional[type | str]:
     "Looks up the annotated type of a property in a class by its property name."
 
     for property_name, property_type in get_class_properties(typ):
@@ -460,13 +460,17 @@ class JsonSchemaGenerator:
         discriminator = None
         if typing.get_origin(data_type) is Annotated:
            discriminator = typing.get_args(data_type)[1].discriminator
-        ret = {"oneOf": [self.type_to_schema(union_type) for union_type in typing.get_args(typ)]}
+        ret: Schema = {"oneOf": [self.type_to_schema(union_type) for union_type in typing.get_args(typ)]}
         if discriminator:
             # for each union type, we need to read the value of the discriminator
-            mapping = {}
+            mapping: dict[str, JsonType] = {}
             for union_type in typing.get_args(typ):
                 props = self.type_to_schema(union_type, force_expand=True)["properties"]
-                mapping[props[discriminator]["default"]] = self.type_to_schema(union_type)["$ref"]
+                # mypy is confused here because JsonType allows multiple types, some of them
+                # not indexable (bool?) or not indexable by string (list?). The correctness of
+                # types depends on correct model definitions. Hence multiple ignore statements below.
+                discriminator_value = props[discriminator]["default"]  # type: ignore[index,call-overload]
+                mapping[discriminator_value] = self.type_to_schema(union_type)["$ref"]  # type: ignore[index]
 
             ret["discriminator"] = {
                 "propertyName": discriminator,
@@ -134,7 +134,10 @@ class IPv6Serializer(Serializer[ipaddress.IPv6Address]):
 
 class EnumSerializer(Serializer[enum.Enum]):
     def generate(self, obj: enum.Enum) -> Union[int, str]:
-        return obj.value
+        value = obj.value
+        if isinstance(value, int):
+            return value
+        return str(value)
 
 
 class UntypedListSerializer(Serializer[list]):
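`enum.Enum.value` is typed as `Any`, so simply returning it does not satisfy the declared `Union[int, str]`; the rewritten `generate` narrows with `isinstance` and stringifies anything else. The same idea in a standalone snippet:

```python
import enum
from typing import Union


class Color(enum.Enum):
    RED = 1
    BLUE = "blue"


def enum_to_json(obj: enum.Enum) -> Union[int, str]:
    value = obj.value  # typed as Any by the stdlib stubs
    if isinstance(value, int):
        return value
    # Anything non-int is coerced to str so the return type stays honest.
    return str(value)


print(enum_to_json(Color.RED), enum_to_json(Color.BLUE))  # 1 blue
```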
@@ -214,7 +214,6 @@ exclude = [
     "^llama_stack/models/llama/llama3/tool_utils\\.py$",
     "^llama_stack/models/llama/llama3_3/prompts\\.py$",
     "^llama_stack/models/llama/sku_list\\.py$",
-    "^llama_stack/providers/datatypes\\.py$",
     "^llama_stack/providers/inline/agents/meta_reference/",
     "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
     "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
@@ -248,7 +247,6 @@ exclude = [
     "^llama_stack/providers/remote/inference/gemini/",
     "^llama_stack/providers/remote/inference/groq/",
     "^llama_stack/providers/remote/inference/nvidia/",
-    "^llama_stack/providers/remote/inference/ollama/",
     "^llama_stack/providers/remote/inference/openai/",
     "^llama_stack/providers/remote/inference/passthrough/",
     "^llama_stack/providers/remote/inference/runpod/",
@@ -256,7 +254,6 @@ exclude = [
     "^llama_stack/providers/remote/inference/sample/",
     "^llama_stack/providers/remote/inference/tgi/",
     "^llama_stack/providers/remote/inference/together/",
-    "^llama_stack/providers/remote/inference/vllm/",
     "^llama_stack/providers/remote/safety/bedrock/",
     "^llama_stack/providers/remote/safety/nvidia/",
     "^llama_stack/providers/remote/safety/sample/",
@@ -292,11 +289,6 @@ exclude = [
     "^llama_stack/providers/utils/telemetry/dataset_mixin\\.py$",
     "^llama_stack/providers/utils/telemetry/trace_protocol\\.py$",
     "^llama_stack/providers/utils/telemetry/tracing\\.py$",
-    "^llama_stack/strong_typing/auxiliary\\.py$",
-    "^llama_stack/strong_typing/deserializer\\.py$",
-    "^llama_stack/strong_typing/inspection\\.py$",
-    "^llama_stack/strong_typing/schema\\.py$",
-    "^llama_stack/strong_typing/serializer\\.py$",
     "^llama_stack/templates/dev/dev\\.py$",
     "^llama_stack/templates/groq/groq\\.py$",
     "^llama_stack/templates/sambanova/sambanova\\.py$",