diff --git a/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift
index 2c03029e0..069a64fcb 100644
--- a/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift
+++ b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift
@@ -45,7 +45,7 @@ public class LocalInference: Inference {
         var tokens: [String] = []

         let prompt = try encodeDialogPrompt(messages: prepareMessages(request: request))
-        var stopReason: Components.Schemas.StopReason? = nil
+        var stopReason: Components.Schemas.CompletionMessage.stop_reasonPayload? = nil
         var buffer = ""
         var ipython = false
         var echoDropped = false
@@ -69,13 +69,13 @@ public class LocalInference: Inference {
                 continuation.yield(
                     Components.Schemas.ChatCompletionResponseStreamChunk(
                         event: Components.Schemas.ChatCompletionResponseEvent(
+                            event_type: .progress,
                             delta: .tool_call(Components.Schemas.ToolCallDelta(
-                                parse_status: Components.Schemas.ToolCallParseStatus.started,
+                                _type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
                                 tool_call: .case1(""),
-                                _type: Components.Schemas.ToolCallDelta._typePayload.tool_call
+                                parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.started
                             )
-                            ),
-                            event_type: .progress
+                            )
                         )
                     )
                 )
@@ -89,9 +89,9 @@ public class LocalInference: Inference {
                 var text = ""
                 if token == "<|eot_id|>" {
-                    stopReason = Components.Schemas.StopReason.end_of_turn
+                    stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_turn
                 } else if token == "<|eom_id|>" {
-                    stopReason = Components.Schemas.StopReason.end_of_message
+                    stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_message
                 } else {
                     text = token
                 }
@@ -99,14 +99,15 @@ public class LocalInference: Inference {
                 var delta: Components.Schemas.ContentDelta
                 if ipython {
                     delta = .tool_call(Components.Schemas.ToolCallDelta(
-                        parse_status: .in_progress,
+                        _type: .tool_call,
                         tool_call: .case1(text),
-                        _type: .tool_call
+                        parse_status: .in_progress
                     ))
                 } else {
                     delta = .text(Components.Schemas.TextDelta(
-                        text: text,
-                        _type: Components.Schemas.TextDelta._typePayload.text)
+                        _type: Components.Schemas.TextDelta._typePayload.text,
+                        text: text
+                    )
                     )
                 }
@@ -114,8 +115,8 @@ public class LocalInference: Inference {
                 continuation.yield(
                     Components.Schemas.ChatCompletionResponseStreamChunk(
                         event: Components.Schemas.ChatCompletionResponseEvent(
-                            delta: delta,
-                            event_type: .progress
+                            event_type: .progress,
+                            delta: delta
                         )
                     )
                 )
@@ -123,41 +124,41 @@ public class LocalInference: Inference {
             }

             if stopReason == nil {
-                stopReason = Components.Schemas.StopReason.out_of_tokens
+                stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.out_of_tokens
             }

             let message = decodeAssistantMessage(tokens: tokens.joined(), stopReason: stopReason!)
             // TODO: non-streaming support

-            let didParseToolCalls = message.tool_calls.count > 0
+            let didParseToolCalls = (message.tool_calls?.count ?? 0) > 0
             if ipython && !didParseToolCalls {
                 continuation.yield(
                     Components.Schemas.ChatCompletionResponseStreamChunk(
                         event: Components.Schemas.ChatCompletionResponseEvent(
+                            event_type: .progress,
                             delta: .tool_call(Components.Schemas.ToolCallDelta(
-                                parse_status: Components.Schemas.ToolCallParseStatus.failed,
+                                _type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
                                 tool_call: .case1(""),
-                                _type: Components.Schemas.ToolCallDelta._typePayload.tool_call
+                                parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.failed
                             )
-                            ),
-                            event_type: .progress
+                            )
                         )
                         // TODO: stopReason
                     )
                 )
             }

-            for toolCall in message.tool_calls {
+            for toolCall in message.tool_calls ?? [] {
                 continuation.yield(
                     Components.Schemas.ChatCompletionResponseStreamChunk(
                         event: Components.Schemas.ChatCompletionResponseEvent(
+                            event_type: .progress,
                             delta: .tool_call(Components.Schemas.ToolCallDelta(
-                                parse_status: Components.Schemas.ToolCallParseStatus.succeeded,
+                                _type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
                                 tool_call: Components.Schemas.ToolCallDelta.tool_callPayload.ToolCall(toolCall),
-                                _type: Components.Schemas.ToolCallDelta._typePayload.tool_call
+                                parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.succeeded
                             )
-                            ),
-                            event_type: .progress
+                            )
                         )
                         // TODO: stopReason
                     )
@@ -167,11 +168,12 @@ public class LocalInference: Inference {
             continuation.yield(
                 Components.Schemas.ChatCompletionResponseStreamChunk(
                     event: Components.Schemas.ChatCompletionResponseEvent(
+                        event_type: .complete,
                         delta: .text(Components.Schemas.TextDelta(
-                            text: "",
-                            _type: Components.Schemas.TextDelta._typePayload.text)
-                        ),
-                        event_type: .complete
+                            _type: Components.Schemas.TextDelta._typePayload.text,
+                            text: ""
+                        )
+                        )
                     )
                     // TODO: stopReason
                 )
diff --git a/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift
index 0ccd646ed..c7f0d65a2 100644
--- a/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift
+++ b/llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift
@@ -38,10 +38,10 @@ func encodeMessage(message: Components.Schemas.Message) -> String {

     switch (message) {
     case .assistant(let m):
-        if (m.tool_calls.count > 0) {
+        if ((m.tool_calls?.count ?? 0) > 0) {
             prompt += "<|python_tag|>"
         }
     default:
         break
     }
@@ -91,7 +91,7 @@ func encodeMessage(message: Components.Schemas.Message) -> String {
         // for t in m.tool_calls {
         //   _processContent(t.)
         //}
-        eom = m.stop_reason == Components.Schemas.StopReason.end_of_message
+        eom = m.stop_reason == Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_message
     case .system(_):
         break
     case .tool(_):
@@ -124,8 +124,9 @@ func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -
     sysContent += try defaultTemplate.render()

     messages.append(.system(Components.Schemas.SystemMessage(
-        content: .case1(sysContent),
-        role: .system))
+        role: .system,
+        content: .case1(sysContent)
+    ))
     )

     if request.tools?.isEmpty == false {
@@ -134,8 +135,8 @@ func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -
         let toolTemplate = try toolGen.gen(customTools: request.tools!)
         let tools = try toolTemplate.render()
         messages.append(.user(Components.Schemas.UserMessage(
-            content: .case1(tools),
-            role: .user)
+            role: .user,
+            content: .case1(tools))
         ))
     }
@@ -193,9 +194,9 @@ public func maybeExtractCustomToolCalls(input: String) -> [Components.Schemas.To
                 result.append(
                     Components.Schemas.ToolCall(
-                        arguments: .init(additionalProperties: props),
                         call_id: UUID().uuidString,
-                        tool_name: .case2(name) // custom_tool
+                        tool_name: .case2(name), // custom_tool
+                        arguments: .init(additionalProperties: props)
                     )
                 )
             }
@@ -206,7 +207,7 @@ public func maybeExtractCustomToolCalls(input: String) -> [Components.Schemas.To
     }
 }

-func decodeAssistantMessage(tokens: String, stopReason: Components.Schemas.StopReason) -> Components.Schemas.CompletionMessage {
+func decodeAssistantMessage(tokens: String, stopReason: Components.Schemas.CompletionMessage.stop_reasonPayload) -> Components.Schemas.CompletionMessage {
     var content = tokens

     let roles = ["user", "system", "assistant"]
@@ -229,8 +230,8 @@ func decodeAssistantMessage(tokens: String, stopReason: Components.Schemas.StopR
     }

     return Components.Schemas.CompletionMessage(
-        content: .case1(content),
         role: .assistant,
+        content: .case1(content),
         stop_reason: stopReason,
         tool_calls: maybeExtractCustomToolCalls(input: content)
     )
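
Reviewer note: the renames above appear to follow the swift-openapi-generator pattern of nesting inline enum payloads under their parent schema, so the standalone Components.Schemas.StopReason becomes Components.Schemas.CompletionMessage.stop_reasonPayload, and ToolCallDelta's discriminator and parse status become its _typePayload and parse_statusPayload; the labeled-argument reordering matches the parameter order of the regenerated initializers. Since CompletionMessage.tool_calls is now optional, call sites guard against nil instead of assuming an array. A minimal sketch of the post-change call shape, using only identifiers that appear in this diff (values are illustrative, not from the test suite):

    // Sketch: building the regenerated CompletionMessage; argument order
    // follows the new initializer (role, content, stop_reason, tool_calls).
    let message = Components.Schemas.CompletionMessage(
        role: .assistant,
        content: .case1("Hello!"),
        stop_reason: Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_turn,
        tool_calls: []
    )

    // tool_calls is optional in the new schema: nil-coalesce rather than force-unwrap.
    let didParseToolCalls = (message.tool_calls?.count ?? 0) > 0
    for toolCall in message.tool_calls ?? [] {
        print(toolCall.call_id)  // handle each parsed tool call
    }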