mirror of https://github.com/meta-llama/llama-stack.git

Pass 1 for pre-commit fixes

parent cfc6bdae68
commit 59e1c5f4a0

7 changed files with 119 additions and 109 deletions
@@ -14,7 +14,9 @@ from llama_stack.apis.inference import (
     ChatCompletionResponse,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -29,7 +31,6 @@ from llama_stack.apis.inference.inference import (
     ChatCompletionResponseStreamChunk,
     CompletionResponse,
     CompletionResponseStreamChunk,
-    ResponseFormatType,
 )
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.remote.inference.lmstudio._client import LMStudioClient
@@ -50,6 +51,18 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
     def client(self) -> LMStudioClient:
         return LMStudioClient(url=self.url)

+    async def batch_chat_completion(self, *args, **kwargs):
+        raise NotImplementedError("Batch chat completion not supported by LM Studio Provider")
+
+    async def batch_completion(self, *args, **kwargs):
+        raise NotImplementedError("Batch completion not supported by LM Studio Provider")
+
+    async def openai_chat_completion(self, *args, **kwargs):
+        raise NotImplementedError("OpenAI chat completion not supported by LM Studio Provider")
+
+    async def openai_completion(self, *args, **kwargs):
+        raise NotImplementedError("OpenAI completion not supported by LM Studio Provider")
+
     async def initialize(self) -> None:
         pass

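For context on the stubs added above: each unsupported endpoint raises NotImplementedError rather than delegating to LM Studio. A minimal sketch of what a caller would see (the stand-in class below is illustrative and not part of this diff; the real adapter also needs a URL and model store):

import asyncio

# Illustrative stand-in mirroring one of the stubs added in this hunk.
class _StubAdapter:
    async def batch_completion(self, *args, **kwargs):
        raise NotImplementedError("Batch completion not supported by LM Studio Provider")

async def main() -> None:
    try:
        await _StubAdapter().batch_completion()
    except NotImplementedError as err:
        print(err)  # Batch completion not supported by LM Studio Provider

asyncio.run(main())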
@@ -71,9 +84,12 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
         assert all(not content_has_media(content) for content in contents), (
             "Media content not supported in embedding model"
         )
+        if self.model_store is None:
+            raise ValueError("ModelStore is not initialized")
         model = await self.model_store.get_model(model_id)
         embedding_model = await self.client.get_embedding_model(model.provider_model_id)
-        embeddings = await self.client.embed(embedding_model, contents)
+        string_contents = [item.text if hasattr(item, "text") else str(item) for item in contents]
+        embeddings = await self.client.embed(embedding_model, string_contents)
         return EmbeddingsResponse(embeddings=embeddings)

     async def chat_completion(
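The embeddings change above stops passing interleaved content objects straight to the client and flattens them to plain strings first. A small sketch of that conversion, using a hypothetical TextItem stand-in (not from the diff):

# Items exposing a `.text` attribute are unwrapped; anything else is coerced with str().
class TextItem:
    def __init__(self, text: str) -> None:
        self.text = text

contents = [TextItem("hello world"), "plain string", 42]
string_contents = [item.text if hasattr(item, "text") else str(item) for item in contents]
print(string_contents)  # ['hello world', 'plain string', '42']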
@@ -81,26 +97,31 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_choice: Optional[ToolChoice] = None, # Default value changed from ToolChoice.auto to None
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[
+            Union[JsonSchemaResponseFormat, GrammarResponseFormat]
+        ] = None, # Moved and type changed
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        if self.model_store is None:
+            raise ValueError("ModelStore is not initialized")
         model = await self.model_store.get_model(model_id)
         llm = await self.client.get_llm(model.provider_model_id)

-        if response_format is not None and response_format.type != ResponseFormatType.json_schema.value:
-            raise ValueError(f"Response format type {response_format.type} not supported for LM Studio Provider")
-        json_schema = response_format.json_schema if response_format else None

+        json_schema_format = response_format if isinstance(response_format, JsonSchemaResponseFormat) else None
+        if response_format is not None and not isinstance(response_format, JsonSchemaResponseFormat):
+            raise ValueError(
+                f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
+            )
         return await self.client.llm_respond(
             llm=llm,
             messages=messages,
             sampling_params=sampling_params,
-            json_schema=json_schema,
+            json_schema=json_schema_format,
             stream=stream,
             tool_config=tool_config,
             tools=tools,
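The chat_completion change above narrows response_format handling: None is allowed, a JsonSchemaResponseFormat is forwarded as json_schema, and any other format (e.g. a grammar format) is rejected by name. A self-contained sketch of that logic with stand-in dataclasses (the real classes come from llama_stack.apis.inference and are assumed here to be simple containers):

from dataclasses import dataclass, field
from typing import Optional, Union

@dataclass
class JsonSchemaResponseFormat:  # illustrative stand-in
    json_schema: dict = field(default_factory=dict)

@dataclass
class GrammarResponseFormat:  # illustrative stand-in
    bnf: str = ""

def validate_response_format(
    response_format: Optional[Union[JsonSchemaResponseFormat, GrammarResponseFormat]],
) -> Optional[JsonSchemaResponseFormat]:
    # Mirrors the new chat_completion logic: pass JSON-schema formats through,
    # allow None, and reject everything else.
    json_schema_format = response_format if isinstance(response_format, JsonSchemaResponseFormat) else None
    if response_format is not None and not isinstance(response_format, JsonSchemaResponseFormat):
        raise ValueError(
            f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
        )
    return json_schema_format

print(validate_response_format(None))  # None
print(validate_response_format(JsonSchemaResponseFormat({"type": "object"})))  # forwarded as-is
# validate_response_format(GrammarResponseFormat("root ::= ..."))  # would raise ValueError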
@@ -115,13 +136,16 @@ class LMStudioInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None, # Skip this for now
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+        if self.model_store is None:
+            raise ValueError("ModelStore is not initialized")
         model = await self.model_store.get_model(model_id)
         llm = await self.client.get_llm(model.provider_model_id)
         if content_has_media(content):
             raise NotImplementedError("Media content not supported in LM Studio Provider")

-        if response_format is not None and response_format.type != ResponseFormatType.json_schema.value:
-            raise ValueError(f"Response format type {response_format.type} not supported for LM Studio Provider")
-        json_schema = response_format.json_schema if response_format else None
+        if not isinstance(response_format, JsonSchemaResponseFormat):
+            raise ValueError(
+                f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
+            )

-        return await self.client.llm_completion(llm, content, sampling_params, json_schema, stream)
+        return await self.client.llm_completion(llm, content, sampling_params, response_format, stream)
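Note that the completion path above, unlike chat_completion, has no `response_format is not None` guard, so as written a call without a response format also fails the isinstance check. A small sketch of that behavior with an illustrative stand-in class:

class JsonSchemaResponseFormat:  # illustrative stand-in, not the real class
    pass

def check(response_format) -> None:
    # Same shape as the new guard in completion(): None is not exempted.
    if not isinstance(response_format, JsonSchemaResponseFormat):
        raise ValueError(
            f"Response format type {type(response_format).__name__} not supported for LM Studio Provider"
        )

check(JsonSchemaResponseFormat())  # accepted
try:
    check(None)
except ValueError as err:
    print(err)  # Response format type NoneType not supported for LM Studio Provider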