Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-06 02:32:40 +00:00
add async to get_model signature in Protocol
Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent 6aedfc2201
commit fcf3b0a835
3 changed files with 7 additions and 7 deletions
@@ -394,7 +394,7 @@ class EmbeddingsResponse(BaseModel):
 
 
 class ModelStore(Protocol):
-    def get_model(self, identifier: str) -> Model: ...
+    async def get_model(self, identifier: str) -> Model: ...
 
 
 class TextTruncation(Enum):
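
Why the Protocol change above ripples into the provider adapters below: once get_model is declared `async def`, every implementation of ModelStore returns a coroutine, so every caller has to `await` it. What follows is a minimal, self-contained sketch, not code from the repository; the InMemoryModelStore class and the trimmed-down Model type are illustrative stand-ins for the real registry and model types.

# Illustrative sketch only: shows that an async Protocol method forces
# async implementations and awaited call sites.
import asyncio
from typing import Protocol


class Model:
    # Stand-in for the real Model type; only the fields used here.
    def __init__(self, identifier: str, provider_resource_id: str) -> None:
        self.identifier = identifier
        self.provider_resource_id = provider_resource_id


class ModelStore(Protocol):
    async def get_model(self, identifier: str) -> Model: ...


class InMemoryModelStore:
    # Hypothetical implementation satisfying the async Protocol.
    def __init__(self, models: dict[str, Model]) -> None:
        self._models = models

    async def get_model(self, identifier: str) -> Model:
        # A real store might do registry or database I/O here; being a
        # coroutine keeps the event loop free while that happens.
        return self._models[identifier]


async def main() -> None:
    store: ModelStore = InMemoryModelStore(
        {"llama3.2:3b": Model("llama3.2:3b", "llama3.2:3b-instruct-fp16")}
    )
    model = await store.get_model("llama3.2:3b")  # must be awaited now
    print(model.provider_resource_id)


asyncio.run(main())
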
@@ -103,7 +103,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = self.model_store.get_model(model_id)
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -167,7 +167,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = self.model_store.get_model(model_id)
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -287,7 +287,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         assert self.model_store is not None
-        model = self.model_store.get_model(model_id)
+        model = await self.model_store.get_model(model_id)
 
         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
@@ -249,7 +249,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = self.model_store.get_model(model_id)
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -279,7 +279,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = self.model_store.get_model(model_id)
+        model = await self.model_store.get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -397,7 +397,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         assert self.client is not None
-        model = self.model_store.get_model(model_id)
+        model = await self.model_store.get_model(model_id)
 
         kwargs = {}
         assert model.model_type == ModelType.embedding
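
For context on why the adapter call sites change in the same commit as the Protocol: a call site left in the old synchronous form receives a coroutine object instead of a Model, which typically surfaces later as an AttributeError at the point of use plus a "coroutine was never awaited" RuntimeWarning. The sketch below is illustrative only, not repository code; FakeStore and the trimmed Model type are hypothetical.

# Illustrative sketch of the failure mode when `await` is missing.
import asyncio


class Model:
    def __init__(self, provider_resource_id: str) -> None:
        self.provider_resource_id = provider_resource_id


class FakeStore:
    async def get_model(self, identifier: str) -> Model:
        return Model(provider_resource_id=identifier)


async def main() -> None:
    store = FakeStore()
    model = store.get_model("llama3.2:3b")  # missing `await`: this is a coroutine
    try:
        print(model.provider_resource_id)
    except AttributeError as exc:
        # 'coroutine' object has no attribute 'provider_resource_id'
        print(f"broken call site: {exc}")
    finally:
        model.close()  # close the never-awaited coroutine to silence the RuntimeWarning


asyncio.run(main())
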