This commit is contained in:
Facundo Santiago 2024-11-04 07:54:31 +00:00
parent 27a0545f5f
commit e247849d1b

View file

@@ -9,9 +9,8 @@ from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from azure.ai.inference.aio import ChatCompletionsClient as ChatCompletionsClientAsync
from azure.core.credentials import AzureKeyCredential
@@ -55,7 +54,7 @@ class AzureAIInferenceAdapter(Inference, ModelsProtocolPrivate):
@property
def client(self) -> ChatCompletionsClientAsync:
if self.config.credential is None:
credential = DefaultAzureCredential()
credential = DefaultAzureCredential()
else:
credential = AzureKeyCredential(self.config.credential)
@@ -68,7 +67,7 @@ class AzureAIInferenceAdapter(Inference, ModelsProtocolPrivate):
)
else:
return ChatCompletionsClientAsync(
endpoint=self.config.endpoint,
endpoint=self.config.endpoint,
credential=credential,
user_agent="llama-stack",
)
@@ -98,7 +97,6 @@ class AzureAIInferenceAdapter(Inference, ModelsProtocolPrivate):
async def list_models(self) -> List[ModelDef]:
print("Model name: ", self._model_name)
if self._model_name is None:
return [
ModelDef(identifier=model_name, llama_model=azure_model_id)