From 33235653ee475fcb28573076b30c7c552050cb79 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sun, 31 Dec 2023 09:39:49 +1300 Subject: [PATCH 1/8] Fix console builds --- OnnxStack.sln | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/OnnxStack.sln b/OnnxStack.sln index be803799..1a23fb3f 100644 --- a/OnnxStack.sln +++ b/OnnxStack.sln @@ -57,14 +57,14 @@ Global {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Release-DirectML|Any CPU.Build.0 = Release|Any CPU {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Release-TensorRT|Any CPU.ActiveCfg = Release|Any CPU {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Release-TensorRT|Any CPU.Build.0 = Release|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug|Any CPU.Build.0 = Debug|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-Cuda|Any CPU.ActiveCfg = Debug|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-Cuda|Any CPU.Build.0 = Debug|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-DirectML|Any CPU.ActiveCfg = Debug|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-DirectML|Any CPU.Build.0 = Debug|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-TensorRT|Any CPU.ActiveCfg = Debug|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-TensorRT|Any CPU.Build.0 = Debug|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug|Any CPU.ActiveCfg = Debug-DirectML|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug|Any CPU.Build.0 = Debug-DirectML|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-Cuda|Any CPU.ActiveCfg = Debug-Cuda|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-Cuda|Any CPU.Build.0 = Debug-Cuda|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-DirectML|Any CPU.ActiveCfg = Debug-DirectML|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-DirectML|Any CPU.Build.0 = Debug-DirectML|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-TensorRT|Any CPU.ActiveCfg = 
Debug-TensorRT|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-TensorRT|Any CPU.Build.0 = Debug-TensorRT|Any CPU {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release|Any CPU.ActiveCfg = Release|Any CPU {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release|Any CPU.Build.0 = Release|Any CPU {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release-Cuda|Any CPU.ActiveCfg = Release|Any CPU From e2c31c2ddd84bb0ecc34e66e29aaea03c367571c Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sun, 31 Dec 2023 12:46:35 +1300 Subject: [PATCH 2/8] Video service for creating and dissecting video --- OnnxStack.Core/Config/OnnxStackConfig.cs | 4 + OnnxStack.Core/OnnxStack.Core.csproj | 2 + OnnxStack.Core/Services/IVideoService.cs | 100 ++++++ OnnxStack.Core/Services/VideoService.cs | 432 +++++++++++++++++++++++ OnnxStack.Core/Video/VideoFrames.cs | 6 + OnnxStack.Core/Video/VideoInfo.cs | 6 + OnnxStack.Core/Video/VideoInput.cs | 65 ++++ OnnxStack.Core/Video/VideoResult.cs | 4 + 8 files changed, 619 insertions(+) create mode 100644 OnnxStack.Core/Services/IVideoService.cs create mode 100644 OnnxStack.Core/Services/VideoService.cs create mode 100644 OnnxStack.Core/Video/VideoFrames.cs create mode 100644 OnnxStack.Core/Video/VideoInfo.cs create mode 100644 OnnxStack.Core/Video/VideoInput.cs create mode 100644 OnnxStack.Core/Video/VideoResult.cs diff --git a/OnnxStack.Core/Config/OnnxStackConfig.cs b/OnnxStack.Core/Config/OnnxStackConfig.cs index fb2c035b..51c66d80 100644 --- a/OnnxStack.Core/Config/OnnxStackConfig.cs +++ b/OnnxStack.Core/Config/OnnxStackConfig.cs @@ -4,6 +4,10 @@ namespace OnnxStack.Core.Config { public class OnnxStackConfig : IConfigSection { + public string TempPath { get; set; } = ".temp"; + public string FFmpegPath { get; set; } = "ffmpeg.exe"; + public string FFprobePath { get; set; } = "ffprobe.exe"; + public void Initialize() { } diff --git a/OnnxStack.Core/OnnxStack.Core.csproj b/OnnxStack.Core/OnnxStack.Core.csproj index 36d37956..2cbabb85 100644 --- 
a/OnnxStack.Core/OnnxStack.Core.csproj +++ b/OnnxStack.Core/OnnxStack.Core.csproj @@ -35,6 +35,7 @@ + @@ -42,6 +43,7 @@ + diff --git a/OnnxStack.Core/Services/IVideoService.cs b/OnnxStack.Core/Services/IVideoService.cs new file mode 100644 index 00000000..6f5d0678 --- /dev/null +++ b/OnnxStack.Core/Services/IVideoService.cs @@ -0,0 +1,100 @@ +using OnnxStack.Core.Video; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace OnnxStack.Core.Services +{ + /// + /// Service with basic handling of video for use in OnnxStack, Frame->Video and Video->Frames + /// + public interface IVideoService + { + /// + /// Gets the video information asynchronous. + /// + /// The video bytes. + /// The cancellation token. + /// + Task GetVideoInfoAsync(byte[] videoBytes, CancellationToken cancellationToken = default); + + /// + /// Gets the video information asynchronous. + /// + /// The video stream. + /// The cancellation token. + /// + Task GetVideoInfoAsync(Stream videoStream, CancellationToken cancellationToken = default); + + /// + /// Gets the video information, Size, FPS, Duration etc. + /// + /// The video input. + /// The cancellation token. + /// + /// No video data found + Task GetVideoInfoAsync(VideoInput videoInput, CancellationToken cancellationToken = default); + + + /// + /// Creates a collection of PNG frames from a video source + /// + /// The video bytes. + /// The video FPS. + /// The cancellation token. + /// + Task CreateFramesAsync(byte[] videoBytes, float videoFPS, CancellationToken cancellationToken = default); + + + /// + /// Creates a collection of PNG frames from a video source + /// + /// The video stream. + /// The video FPS. + /// The cancellation token. + /// + Task CreateFramesAsync(Stream videoStream, float videoFPS, CancellationToken cancellationToken = default); + + + /// + /// Creates a collection of PNG frames from a video source + /// + /// The video input. + /// The video FPS. 
+ /// The cancellation token. + /// + /// VideoTensor not supported + /// No video data found + Task CreateFramesAsync(VideoInput videoInput, float videoFPS, CancellationToken cancellationToken = default); + + + /// + /// Creates and MP4 video from a collection of PNG images. + /// + /// The video frames. + /// The video FPS. + /// The cancellation token. + /// + Task CreateVideoAsync(IEnumerable videoFrames, float videoFPS, CancellationToken cancellationToken = default); + + + /// + /// Creates and MP4 video from a collection of PNG images. + /// + /// The video frames. + /// The cancellation token. + /// + Task CreateVideoAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default); + + + /// + /// Streams frames as PNG as they are processed from a video source + /// + /// The video bytes. + /// The target FPS. + /// The cancellation token. + /// + IAsyncEnumerable StreamFramesAsync(byte[] videoBytes, float targetFPS, CancellationToken cancellationToken = default); + } +} \ No newline at end of file diff --git a/OnnxStack.Core/Services/VideoService.cs b/OnnxStack.Core/Services/VideoService.cs new file mode 100644 index 00000000..421cbb36 --- /dev/null +++ b/OnnxStack.Core/Services/VideoService.cs @@ -0,0 +1,432 @@ +using FFMpegCore; +using OnnxStack.Core.Config; +using OnnxStack.Core.Video; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace OnnxStack.Core.Services +{ + /// + /// Service with basic handling of video for use in OnnxStack, Frame->Video and Video->Frames + /// + public class VideoService : IVideoService + { + private readonly OnnxStackConfig _configuration; + + /// + /// Initializes a new instance of the class. + /// + /// The configuration. 
+        public VideoService(OnnxStackConfig configuration)
+        {
+            _configuration = configuration;
+        }
+
+        #region Public Members
+
+        /// <summary>
+        /// Gets the video information, Size, FPS, Duration etc.
+        /// </summary>
+        /// <param name="videoInput">The video input; VideoBytes is used when set, otherwise VideoStream.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The <see cref="VideoInfo"/> of the supplied video.</returns>
+        /// <exception cref="NotSupportedException">VideoTensor not supported</exception>
+        /// <exception cref="ArgumentException">No video data found</exception>
+        public async Task<VideoInfo> GetVideoInfoAsync(VideoInput videoInput, CancellationToken cancellationToken = default)
+        {
+            if (videoInput.VideoBytes is not null)
+                return await GetVideoInfoAsync(videoInput.VideoBytes, cancellationToken);
+            if (videoInput.VideoStream is not null)
+                return await GetVideoInfoAsync(videoInput.VideoStream, cancellationToken);
+            if (videoInput.VideoTensor is not null)
+                throw new NotSupportedException("VideoTensor not supported");
+
+            throw new ArgumentException("No video data found");
+        }
+
+
+        /// <summary>
+        /// Gets the video information asynchronous.
+        /// </summary>
+        /// <param name="videoStream">The video stream; it is read from its current position to its end.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The <see cref="VideoInfo"/> of the supplied video.</returns>
+        public async Task<VideoInfo> GetVideoInfoAsync(Stream videoStream, CancellationToken cancellationToken = default)
+        {
+            using (var memoryStream = new MemoryStream())
+            {
+                // Copy FROM the caller's stream INTO the buffer (the reverse direction copies nothing),
+                // then analyse via the byte[] overload so the data is read from position 0.
+                await videoStream.CopyToAsync(memoryStream, cancellationToken);
+                return await GetVideoInfoAsync(memoryStream.ToArray(), cancellationToken);
+            }
+        }
+
+
+        /// <summary>
+        /// Gets the video information asynchronous.
+        /// </summary>
+        /// <param name="videoBytes">The video bytes.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The <see cref="VideoInfo"/> of the supplied video.</returns>
+        public async Task<VideoInfo> GetVideoInfoAsync(byte[] videoBytes, CancellationToken cancellationToken = default)
+        {
+            using (var videoStream = new MemoryStream(videoBytes))
+            {
+                return await GetVideoInfoInternalAsync(videoStream, cancellationToken);
+            }
+        }
+
+
+        /// <summary>
+        /// Creates an MP4 video from a collection of PNG images.
+        /// </summary>
+        /// <param name="videoFrames">The video frames.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The encoded video bytes together with their <see cref="VideoInfo"/>.</returns>
+        public async Task<VideoResult> CreateVideoAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default)
+        {
+            return await CreateVideoInternalAsync(videoFrames.Frames, videoFrames.FPS, cancellationToken);
+        }
+
+
+        /// <summary>
+        /// Creates an MP4 video from a collection of PNG images.
+        /// </summary>
+        /// <param name="videoFrames">The video frames.</param>
+        /// <param name="videoFPS">The video FPS.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The encoded video bytes together with their <see cref="VideoInfo"/>.</returns>
+        public async Task<VideoResult> CreateVideoAsync(IEnumerable<byte[]> videoFrames, float videoFPS, CancellationToken cancellationToken = default)
+        {
+            return await CreateVideoInternalAsync(videoFrames, videoFPS, cancellationToken);
+        }
+
+
+        /// <summary>
+        /// Creates a collection of PNG frames from a video source
+        /// </summary>
+        /// <param name="videoInput">The video input; VideoBytes is used when set, otherwise VideoStream.</param>
+        /// <param name="videoFPS">The video FPS.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The extracted frames together with the requested FPS.</returns>
+        /// <exception cref="NotSupportedException">VideoTensor not supported</exception>
+        /// <exception cref="ArgumentException">No video data found</exception>
+        public async Task<VideoFrames> CreateFramesAsync(VideoInput videoInput, float videoFPS, CancellationToken cancellationToken = default)
+        {
+            if (videoInput.VideoBytes is not null)
+                return await CreateFramesAsync(videoInput.VideoBytes, videoFPS, cancellationToken);
+            if (videoInput.VideoStream is not null)
+                return await CreateFramesAsync(videoInput.VideoStream, videoFPS, cancellationToken);
+            if (videoInput.VideoTensor is not null)
+                throw new NotSupportedException("VideoTensor not supported");
+
+            throw new ArgumentException("No video data found");
+        }
+
+
+        /// <summary>
+        /// Creates a collection of PNG frames from a video source
+        /// </summary>
+        /// <param name="videoBytes">The video bytes.</param>
+        /// <param name="videoFPS">The video FPS.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The extracted frames together with the requested FPS.</returns>
+        public async Task<VideoFrames> CreateFramesAsync(byte[] videoBytes, float videoFPS, CancellationToken cancellationToken = default)
+        {
+            // All frames are collected into a list before returning; StreamFramesAsync yields them one at a time instead
+            return new VideoFrames(videoFPS, await CreateFramesInternalAsync(videoBytes, videoFPS, cancellationToken).ToListAsync(cancellationToken));
+        }
+
+
+        /// <summary>
+        /// Creates a collection of PNG frames from a video source
+        /// </summary>
+        /// <param name="videoStream">The video stream.</param>
+        /// <param name="videoFPS">The video FPS.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The extracted frames together with the requested FPS.</returns>
+        public async Task<VideoFrames> CreateFramesAsync(Stream videoStream, float videoFPS, CancellationToken cancellationToken = default)
+        {
+            using (var memoryStream = new MemoryStream())
+            {
+                // Buffer the caller's stream: videoStream is the source, memoryStream the destination
+                // (copying the other way round hands an empty array to FFMPEG).
+                await videoStream.CopyToAsync(memoryStream, cancellationToken).ConfigureAwait(false);
+                return new VideoFrames(videoFPS, await CreateFramesInternalAsync(memoryStream.ToArray(), videoFPS, cancellationToken).ToListAsync(cancellationToken));
+            }
+        }
+
+
+        /// <summary>
+        /// Streams frames as PNG as they are processed from a video source
+        /// </summary>
+        /// <param name="videoBytes">The video bytes.</param>
+        /// <param name="targetFPS">The target FPS.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>An async sequence with one PNG byte array per frame.</returns>
+        public IAsyncEnumerable<byte[]> StreamFramesAsync(byte[] videoBytes, float targetFPS, CancellationToken cancellationToken = default)
+        {
+            return CreateFramesInternalAsync(videoBytes, targetFPS, cancellationToken);
+        }
+
+        #endregion
+
+        #region Private Members
+
+
+        /// <summary>
+        /// Gets the video information.
+        /// </summary>
+        /// <param name="videoStream">The video stream.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>Width, height and frame rate of the primary video stream plus the overall duration.</returns>
+        /// <exception cref="ArgumentException">No video data found</exception>
+        private async Task<VideoInfo> GetVideoInfoInternalAsync(MemoryStream videoStream, CancellationToken cancellationToken = default)
+        {
+            var result = await FFProbe.AnalyseAsync(videoStream, cancellationToken: cancellationToken).ConfigureAwait(false);
+
+            // Fail with the class's usual message instead of a NullReferenceException when no video stream was found
+            var primaryStream = result.PrimaryVideoStream ?? throw new ArgumentException("No video data found");
+
+            // NOTE(review): the (int) cast drops any fractional part of the frame rate (e.g. 29.97 becomes 29)
+            return new VideoInfo(primaryStream.Width, primaryStream.Height, result.Duration, (int)primaryStream.FrameRate);
+        }
+
+
+        /// <summary>
+        /// Creates an MP4 video from a collection of PNG frames
+        /// </summary>
+        /// <param name="imageData">The image data.</param>
+        /// <param name="fps">The FPS.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The encoded MP4 bytes and the <see cref="VideoInfo"/> probed from them.</returns>
+        /// <exception cref="ArgumentException">No video data found</exception>
+        private async Task<VideoResult> CreateVideoInternalAsync(IEnumerable<byte[]> imageData, float fps = 15, CancellationToken cancellationToken = default)
+        {
+            // Walk the sequence exactly once: a lazily produced sequence would otherwise be
+            // evaluated twice (once to probe the first frame, once to write all frames).
+            using var frameEnumerator = imageData.GetEnumerator();
+            if (!frameEnumerator.MoveNext())
+                throw new ArgumentException("No video data found", nameof(imageData));
+
+            string tempVideoPath = GetTempFilename();
+            try
+            {
+                // Analyze first frame to get some details
+                var frameInfo = await GetVideoInfoAsync(frameEnumerator.Current, cancellationToken);
+                var aspectRatio = (double)frameInfo.Width / frameInfo.Height;
+                using (var videoWriter = CreateWriter(tempVideoPath, fps, aspectRatio))
+                {
+                    // Start FFMPEG
+                    videoWriter.Start();
+                    try
+                    {
+                        do
+                        {
+                            // Write each frame to the input stream of FFMPEG
+                            await videoWriter.StandardInput.BaseStream.WriteAsync(frameEnumerator.Current, cancellationToken);
+                        }
+                        while (frameEnumerator.MoveNext());
+
+                        // Done close stream and wait for app to process
+                        videoWriter.StandardInput.BaseStream.Close();
+                        await videoWriter.WaitForExitAsync(cancellationToken);
+                    }
+                    catch
+                    {
+                        // Disposing the Process object does not stop FFMPEG; terminate it so it does not
+                        // linger and keep the temp file open after a failed or cancelled write.
+                        if (!videoWriter.HasExited)
+                        {
+                            videoWriter.Kill();
+                            videoWriter.WaitForExit();
+                        }
+                        throw;
+                    }
+
+                    // Read result from temp file
+                    var videoResult = await File.ReadAllBytesAsync(tempVideoPath, cancellationToken);
+
+                    // Analyze the result
+                    var videoInfo = await GetVideoInfoAsync(videoResult, cancellationToken);
+                    return new VideoResult(videoResult, videoInfo);
+                }
+            }
+            finally
+            {
+                DeleteTempFile(tempVideoPath);
+            }
+        }
+
+
+        /// <summary>
+        /// Creates a collection of PNG frames from a video source
+        /// </summary>
+        /// <param name="videoData">The video data.</param>
+        /// <param name="fps">The FPS.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>An async sequence with one PNG byte array per frame written by FFMPEG.</returns>
+        /// <exception cref="InvalidDataException">Invalid PNG header, or the FFMPEG output ended or was malformed inside a frame</exception>
+        private async IAsyncEnumerable<byte[]> CreateFramesInternalAsync(byte[] videoData, float fps = 15, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            string tempVideoPath = GetTempFilename();
+            try
+            {
+                await File.WriteAllBytesAsync(tempVideoPath, videoData, cancellationToken);
+                using (var ffmpegProcess = CreateReader(tempVideoPath, fps))
+                {
+                    // Start FFMPEG
+                    ffmpegProcess.Start();
+                    try
+                    {
+                        // FFMPEG output stream
+                        var processOutputStream = ffmpegProcess.StandardOutput.BaseStream;
+
+                        // Buffer to hold the current image, grown on demand when a frame does not fit
+                        var buffer = new byte[20480000];
+
+                        while (!cancellationToken.IsCancellationRequested)
+                        {
+                            // Read the PNG Header; zero bytes at a frame boundary means FFMPEG has no more frames
+                            var headerRead = await FillAsync(processOutputStream, buffer, 0, 8, cancellationToken);
+                            if (headerRead == 0)
+                                break;
+
+                            if (headerRead < 8 || !IsImageHeader(buffer))
+                                throw new InvalidDataException("Invalid PNG header");
+
+                            var currentIndex = 8; // header length
+
+                            // loop through each chunk
+                            while (true)
+                            {
+                                // Read 12 bytes: 4-byte length + 4-byte type + 4 more bytes. Together with the
+                                // 'length' bytes read below this always spans one whole chunk (length, type, data, CRC).
+                                buffer = EnsureCapacity(buffer, currentIndex, 12);
+                                if (await FillAsync(processOutputStream, buffer, currentIndex, 12, cancellationToken) < 12)
+                                    throw new InvalidDataException("Unexpected end of PNG data");
+
+                                var chunkIndex = currentIndex;
+                                currentIndex += 12; // Chunk header length
+
+                                // Get the chunk's content size in bytes from the header we just read
+                                var totalSize = buffer[chunkIndex] << 24 | buffer[chunkIndex + 1] << 16 | buffer[chunkIndex + 2] << 8 | buffer[chunkIndex + 3];
+                                if (totalSize < 0)
+                                    throw new InvalidDataException("Invalid PNG chunk length");
+
+                                if (totalSize > 0)
+                                {
+                                    buffer = EnsureCapacity(buffer, currentIndex, totalSize);
+                                    if (await FillAsync(processOutputStream, buffer, currentIndex, totalSize, cancellationToken) < totalSize)
+                                        throw new InvalidDataException("Unexpected end of PNG data");
+
+                                    currentIndex += totalSize;
+                                    continue;
+                                }
+
+                                // If the size is 0 and is the end of the image
+                                if (IsImageEnd(buffer, chunkIndex))
+                                    break;
+                            }
+
+                            // Return a copy of the completed image, the buffer is reused for the next frame
+                            var frame = new byte[currentIndex];
+                            Buffer.BlockCopy(buffer, 0, frame, 0, currentIndex);
+                            yield return frame;
+                        }
+                    }
+                    finally
+                    {
+                        // Runs on completion, cancellation, error and when the consumer stops enumerating early.
+                        // Disposing the Process object does not stop FFMPEG, so terminate it if it is still running.
+                        if (!ffmpegProcess.HasExited)
+                        {
+                            ffmpegProcess.Kill();
+                            ffmpegProcess.WaitForExit();
+                        }
+                    }
+                }
+            }
+            finally
+            {
+                DeleteTempFile(tempVideoPath);
+            }
+        }
+
+
+        /// <summary>
+        /// Reads from the stream until the requested number of bytes is stored or the stream ends.
+        /// </summary>
+        /// <param name="stream">The stream to read.</param>
+        /// <param name="buffer">The destination buffer.</param>
+        /// <param name="offset">The index in the buffer at which to start storing.</param>
+        /// <param name="count">The number of bytes wanted.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The number of bytes stored; less than count only when the stream ended first.</returns>
+        private static async Task<int> FillAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        {
+            var total = 0;
+            while (total < count)
+            {
+                // A single ReadAsync may return fewer bytes than requested; it returns 0 only at end of stream
+                var read = await stream.ReadAsync(buffer, offset + total, count - total, cancellationToken);
+                if (read == 0)
+                    break;
+
+                total += read;
+            }
+            return total;
+        }
+
+
+        /// <summary>
+        /// Returns a buffer that can hold the bytes already used plus the additional bytes, growing it when needed.
+        /// </summary>
+        /// <param name="buffer">The current buffer.</param>
+        /// <param name="used">The number of bytes already stored.</param>
+        /// <param name="additional">The number of bytes about to be stored.</param>
+        /// <returns>The same array when large enough, otherwise a larger copy.</returns>
+        /// <exception cref="InvalidDataException">The frame would exceed the maximum array size</exception>
+        private static byte[] EnsureCapacity(byte[] buffer, int used, int additional)
+        {
+            var required = (long)used + additional;
+            if (required <= buffer.Length)
+                return buffer;
+
+            if (required > int.MaxValue)
+                throw new InvalidDataException("PNG frame too large");
+
+            // At least double to keep the number of re-allocations small
+            var newSize = Math.Max(required, Math.Min((long)buffer.Length * 2, int.MaxValue));
+            Array.Resize(ref buffer, (int)newSize);
+            return buffer;
+        }
+
+
+        /// <summary>
+        /// Gets the temporary filename.
+        /// </summary>
+        /// <returns>A randomly named .mp4 path inside the configured temp directory.</returns>
+        private string GetTempFilename()
+        {
+            // CreateDirectory does nothing when the directory already exists
+            Directory.CreateDirectory(_configuration.TempPath);
+
+            return Path.Combine(_configuration.TempPath, $"{Path.GetFileNameWithoutExtension(Path.GetRandomFileName())}.mp4");
+        }
+
+
+        /// <summary>
+        /// Deletes the temporary file.
+        /// </summary>
+        /// <param name="filename">The filename.</param>
+        private void DeleteTempFile(string filename)
+        {
+            try
+            {
+                if (File.Exists(filename))
+                    File.Delete(filename);
+            }
+            catch (Exception ex)
+            {
+                // Best effort clean-up (e.g. file in use): never let it mask the caller's own result or exception
+                Debug.WriteLine($"Failed to delete temp file '{filename}': {ex.Message}");
+            }
+        }
+
+
+        /// <summary>
+        /// Creates FFMPEG video reader process.
+        /// </summary>
+        /// <param name="inputFile">The input file.</param>
+        /// <param name="fps">The FPS.</param>
+        /// <returns>A configured, not yet started process that writes PNG frames to its standard output.</returns>
+        private Process CreateReader(string inputFile, float fps)
+        {
+            var ffmpegProcess = new Process();
+            ffmpegProcess.StartInfo.FileName = _configuration.FFmpegPath;
+
+            // Invariant formatting keeps '.' as the decimal separator whatever the current culture
+            ffmpegProcess.StartInfo.Arguments = FormattableString.Invariant($"-hide_banner -loglevel error -i \"{inputFile}\" -c:v png -r {fps} -f image2pipe -");
+            ffmpegProcess.StartInfo.RedirectStandardOutput = true;
+            ffmpegProcess.StartInfo.UseShellExecute = false;
+            ffmpegProcess.StartInfo.CreateNoWindow = true;
+            return ffmpegProcess;
+        }
+
+
+        /// <summary>
+        /// Creates FFMPEG video writer process.
+        /// </summary>
+        /// <param name="outputFile">The output file.</param>
+        /// <param name="fps">The FPS.</param>
+        /// <param name="aspectRatio">The aspect ratio.</param>
+        /// <returns>A configured, not yet started process that reads frames from its standard input.</returns>
+        private Process CreateWriter(string outputFile, float fps, double aspectRatio)
+        {
+            var ffmpegProcess = new Process();
+            ffmpegProcess.StartInfo.FileName = _configuration.FFmpegPath;
+
+            // Invariant formatting keeps '.' as the decimal separator whatever the current culture,
+            // and the output path is quoted so a temp path containing spaces stays a single argument
+            ffmpegProcess.StartInfo.Arguments = FormattableString.Invariant($"-hide_banner -loglevel error -framerate {fps:F4} -i - -c:v libx264 -movflags +faststart -vf format=yuv420p -aspect {aspectRatio} \"{outputFile}\"");
+            ffmpegProcess.StartInfo.RedirectStandardInput = true;
+            ffmpegProcess.StartInfo.UseShellExecute = false;
+            ffmpegProcess.StartInfo.CreateNoWindow = true;
+            return ffmpegProcess;
+        }
+
+
+        /// <summary>
+        /// Determines whether we are at the start of a PNG image in the specified buffer.
+        /// </summary>
+        /// <param name="buffer">The buffer; its first 8 bytes are compared with the PNG signature.</param>
+        /// <returns>
+        ///   <c>true</c> if the start of a PNG image sequence is detected; otherwise, <c>false</c>.
+        /// </returns>
+        private static bool IsImageHeader(byte[] buffer)
+        {
+            // PNG Header http://www.libpng.org/pub/png/spec/1.2/PNG-Structure.html#PNG-file-signature
+            return buffer[0] == 0x89
+                && buffer[1] == 0x50
+                && buffer[2] == 0x4E
+                && buffer[3] == 0x47
+                && buffer[4] == 0x0D
+                && buffer[5] == 0x0A
+                && buffer[6] == 0x1A
+                && buffer[7] == 0x0A;
+        }
+
+
+        /// <summary>
+        /// Determines whether we are at the end of a PNG image in the specified buffer.
+        /// </summary>
+        /// <param name="buffer">The buffer.</param>
+        /// <param name="offset">The offset.</param>
+        /// <returns><c>true</c> if the end of a PNG image sequence is detected; otherwise, <c>false</c>.</returns>
+ /// + private static bool IsImageEnd(byte[] buffer, int offset) + { + return buffer[offset + 4] == 0x49 // I + && buffer[offset + 5] == 0x45 // E + && buffer[offset + 6] == 0x4E // N + && buffer[offset + 7] == 0x44; // D + } + } + + #endregion +} diff --git a/OnnxStack.Core/Video/VideoFrames.cs b/OnnxStack.Core/Video/VideoFrames.cs new file mode 100644 index 00000000..4044c845 --- /dev/null +++ b/OnnxStack.Core/Video/VideoFrames.cs @@ -0,0 +1,6 @@ +using System.Collections.Generic; + +namespace OnnxStack.Core.Video +{ + public record VideoFrames(float FPS, List Frames); +} diff --git a/OnnxStack.Core/Video/VideoInfo.cs b/OnnxStack.Core/Video/VideoInfo.cs new file mode 100644 index 00000000..c589a46e --- /dev/null +++ b/OnnxStack.Core/Video/VideoInfo.cs @@ -0,0 +1,6 @@ +using System; + +namespace OnnxStack.Core.Video +{ + public record VideoInfo(int Width, int Height, TimeSpan Duration, int Fps); +} diff --git a/OnnxStack.Core/Video/VideoInput.cs b/OnnxStack.Core/Video/VideoInput.cs new file mode 100644 index 00000000..ef2b268e --- /dev/null +++ b/OnnxStack.Core/Video/VideoInput.cs @@ -0,0 +1,65 @@ +using Microsoft.ML.OnnxRuntime.Tensors; +using System.IO; +using System.Text.Json.Serialization; + +namespace OnnxStack.Core.Video +{ + public class VideoInput + { + /// + /// Initializes a new instance of the class. + /// + public VideoInput() { } + + /// + /// Initializes a new instance of the class. + /// + /// The video bytes. + public VideoInput(byte[] videoBytes) => VideoBytes = videoBytes; + + /// + /// Initializes a new instance of the class. + /// + /// The video stream. + public VideoInput(Stream videoStream) => VideoStream = videoStream; + + /// + /// Initializes a new instance of the class. + /// + /// The video tensor. + public VideoInput(DenseTensor videoTensor) => VideoTensor = videoTensor; + + + /// + /// Gets the video bytes. + /// + [JsonIgnore] + public byte[] VideoBytes { get; set; } + + + /// + /// Gets the video stream. 
+ /// + [JsonIgnore] + public Stream VideoStream { get; set; } + + + /// + /// Gets the video tensor. + /// + [JsonIgnore] + public DenseTensor VideoTensor { get; set; } + + + /// + /// Gets a value indicating whether this instance has video. + /// + /// + /// true if this instance has video; otherwise, false. + /// + [JsonIgnore] + public bool HasVideo => VideoBytes != null + || VideoStream != null + || VideoTensor != null; + } +} diff --git a/OnnxStack.Core/Video/VideoResult.cs b/OnnxStack.Core/Video/VideoResult.cs new file mode 100644 index 00000000..acb19cdd --- /dev/null +++ b/OnnxStack.Core/Video/VideoResult.cs @@ -0,0 +1,4 @@ +namespace OnnxStack.Core.Video +{ + public record VideoResult(byte[] Data, VideoInfo Info); +} From 47836c1ae844f93a9e57dddec4e7afc0a795e7f7 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sun, 31 Dec 2023 15:19:51 +1300 Subject: [PATCH 3/8] StableDiffusionService initial video support --- OnnxStack.Core/Services/IVideoService.cs | 4 +- OnnxStack.Core/Services/VideoService.cs | 19 ++++--- OnnxStack.Core/Video/VideoFrames.cs | 2 +- OnnxStack.Core/Video/VideoInfo.cs | 2 +- OnnxStack.Core/Video/VideoOutput.cs | 4 ++ OnnxStack.Core/Video/VideoResult.cs | 4 -- .../Config/PromptOptions.cs | 4 ++ .../Config/SchedulerOptions.cs | 2 + .../Enums/DiffuserType.cs | 5 +- .../Services/StableDiffusionService.cs | 50 +++++++++++++++---- 10 files changed, 70 insertions(+), 26 deletions(-) create mode 100644 OnnxStack.Core/Video/VideoOutput.cs delete mode 100644 OnnxStack.Core/Video/VideoResult.cs diff --git a/OnnxStack.Core/Services/IVideoService.cs b/OnnxStack.Core/Services/IVideoService.cs index 6f5d0678..465aa501 100644 --- a/OnnxStack.Core/Services/IVideoService.cs +++ b/OnnxStack.Core/Services/IVideoService.cs @@ -76,7 +76,7 @@ public interface IVideoService /// The video FPS. /// The cancellation token. 
/// - Task CreateVideoAsync(IEnumerable videoFrames, float videoFPS, CancellationToken cancellationToken = default); + Task CreateVideoAsync(IEnumerable videoFrames, float videoFPS, CancellationToken cancellationToken = default); /// @@ -85,7 +85,7 @@ public interface IVideoService /// The video frames. /// The cancellation token. /// - Task CreateVideoAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default); + Task CreateVideoAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default); /// diff --git a/OnnxStack.Core/Services/VideoService.cs b/OnnxStack.Core/Services/VideoService.cs index 421cbb36..08f7a2eb 100644 --- a/OnnxStack.Core/Services/VideoService.cs +++ b/OnnxStack.Core/Services/VideoService.cs @@ -88,9 +88,9 @@ public async Task GetVideoInfoAsync(byte[] videoBytes, CancellationTo /// The video frames. /// The cancellation token. /// - public async Task CreateVideoAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default) + public async Task CreateVideoAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default) { - return await CreateVideoInternalAsync(videoFrames.Frames, videoFrames.FPS, cancellationToken); + return await CreateVideoInternalAsync(videoFrames.Frames, videoFrames.Info.FPS, cancellationToken); } @@ -101,7 +101,7 @@ public async Task CreateVideoAsync(VideoFrames videoFrames, Cancell /// The video FPS. /// The cancellation token. 
/// - public async Task CreateVideoAsync(IEnumerable videoFrames, float videoFPS, CancellationToken cancellationToken = default) + public async Task CreateVideoAsync(IEnumerable videoFrames, float videoFPS, CancellationToken cancellationToken = default) { return await CreateVideoInternalAsync(videoFrames, videoFPS, cancellationToken); } @@ -139,7 +139,9 @@ public async Task CreateFramesAsync(VideoInput videoInput, float vi /// public async Task CreateFramesAsync(byte[] videoBytes, float videoFPS, CancellationToken cancellationToken = default) { - return new VideoFrames(videoFPS, await CreateFramesInternalAsync(videoBytes, videoFPS, cancellationToken).ToListAsync(cancellationToken)); + var videoInfo = await GetVideoInfoAsync(videoBytes, cancellationToken); + var videoFrames = await CreateFramesInternalAsync(videoBytes, videoFPS, cancellationToken).ToListAsync(cancellationToken); + return new VideoFrames(videoInfo, videoFrames); } @@ -155,7 +157,10 @@ public async Task CreateFramesAsync(Stream videoStream, float video using (var memoryStream = new MemoryStream()) { await memoryStream.CopyToAsync(videoStream, cancellationToken).ConfigureAwait(false); - return new VideoFrames(videoFPS, await CreateFramesInternalAsync(memoryStream.ToArray(), videoFPS, cancellationToken).ToListAsync(cancellationToken)); + var videoBytes = memoryStream.ToArray(); + var videoInfo = await GetVideoInfoAsync(videoBytes, cancellationToken); + var videoFrames = await CreateFramesInternalAsync(videoBytes, videoFPS, cancellationToken).ToListAsync(cancellationToken); + return new VideoFrames(videoInfo, videoFrames); } } @@ -197,7 +202,7 @@ private async Task GetVideoInfoInternalAsync(MemoryStream videoStream /// The FPS. /// The cancellation token. 
/// - private async Task CreateVideoInternalAsync(IEnumerable imageData, float fps = 15, CancellationToken cancellationToken = default) + private async Task CreateVideoInternalAsync(IEnumerable imageData, float fps = 15, CancellationToken cancellationToken = default) { string tempVideoPath = GetTempFilename(); try @@ -224,7 +229,7 @@ private async Task CreateVideoInternalAsync(IEnumerable ima // Analyze the result var videoInfo = await GetVideoInfoAsync(videoResult); - return new VideoResult(videoResult, videoInfo); + return new VideoOutput(videoResult, videoInfo); } } finally diff --git a/OnnxStack.Core/Video/VideoFrames.cs b/OnnxStack.Core/Video/VideoFrames.cs index 4044c845..3c4d35bc 100644 --- a/OnnxStack.Core/Video/VideoFrames.cs +++ b/OnnxStack.Core/Video/VideoFrames.cs @@ -2,5 +2,5 @@ namespace OnnxStack.Core.Video { - public record VideoFrames(float FPS, List Frames); + public record VideoFrames(VideoInfo Info, IReadOnlyList Frames); } diff --git a/OnnxStack.Core/Video/VideoInfo.cs b/OnnxStack.Core/Video/VideoInfo.cs index c589a46e..d3a91f03 100644 --- a/OnnxStack.Core/Video/VideoInfo.cs +++ b/OnnxStack.Core/Video/VideoInfo.cs @@ -2,5 +2,5 @@ namespace OnnxStack.Core.Video { - public record VideoInfo(int Width, int Height, TimeSpan Duration, int Fps); + public record VideoInfo(int Width, int Height, TimeSpan Duration, int FPS); } diff --git a/OnnxStack.Core/Video/VideoOutput.cs b/OnnxStack.Core/Video/VideoOutput.cs new file mode 100644 index 00000000..db4195b3 --- /dev/null +++ b/OnnxStack.Core/Video/VideoOutput.cs @@ -0,0 +1,4 @@ +namespace OnnxStack.Core.Video +{ + public record VideoOutput(byte[] Data, VideoInfo Info); +} diff --git a/OnnxStack.Core/Video/VideoResult.cs b/OnnxStack.Core/Video/VideoResult.cs deleted file mode 100644 index acb19cdd..00000000 --- a/OnnxStack.Core/Video/VideoResult.cs +++ /dev/null @@ -1,4 +0,0 @@ -namespace OnnxStack.Core.Video -{ - public record VideoResult(byte[] Data, VideoInfo Info); -} diff --git 
a/OnnxStack.StableDiffusion/Config/PromptOptions.cs b/OnnxStack.StableDiffusion/Config/PromptOptions.cs index f6713da5..5c0ab60b 100644 --- a/OnnxStack.StableDiffusion/Config/PromptOptions.cs +++ b/OnnxStack.StableDiffusion/Config/PromptOptions.cs @@ -1,4 +1,5 @@ using OnnxStack.Core.Image; +using OnnxStack.Core.Video; using OnnxStack.StableDiffusion.Enums; using System.ComponentModel.DataAnnotations; @@ -19,6 +20,9 @@ public class PromptOptions public InputImage InputImageMask { get; set; } + public VideoFrames InputVideo { get; set; } + + public bool HasInputVideo => InputVideo?.Frames?.Count > 0; public bool HasInputImage => InputImage?.HasImage ?? false; public bool HasInputImageMask => InputImageMask?.HasImage ?? false; } diff --git a/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs b/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs index 3b258a0c..50e33555 100644 --- a/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs +++ b/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs @@ -84,6 +84,8 @@ public record SchedulerOptions public float AestheticScore { get; set; } = 6f; public float AestheticNegativeScore { get; set; } = 2.5f; + public float VideoFPS { get; set; } + public bool IsKarrasScheduler { get diff --git a/OnnxStack.StableDiffusion/Enums/DiffuserType.cs b/OnnxStack.StableDiffusion/Enums/DiffuserType.cs index 96a127f1..dd4c997a 100644 --- a/OnnxStack.StableDiffusion/Enums/DiffuserType.cs +++ b/OnnxStack.StableDiffusion/Enums/DiffuserType.cs @@ -17,6 +17,9 @@ public enum DiffuserType ImageInpaintLegacy = 3, [Description("Image To Animation")] - ImageToAnimation = 4 + ImageToAnimation = 4, + + [Description("Video To Video")] + VideoToVideo = 5 } } diff --git a/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs b/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs index 24302efe..91e9b256 100644 --- a/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs +++ 
b/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs @@ -1,6 +1,5 @@ using Microsoft.ML.OnnxRuntime.Tensors; using OnnxStack.Core; -using OnnxStack.Core.Config; using OnnxStack.Core.Services; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; @@ -26,6 +25,7 @@ namespace OnnxStack.StableDiffusion.Services /// public sealed class StableDiffusionService : IStableDiffusionService { + private readonly IVideoService _videoService; private readonly IOnnxModelService _modelService; private readonly StableDiffusionConfig _configuration; private readonly ConcurrentDictionary _pipelines; @@ -34,10 +34,11 @@ public sealed class StableDiffusionService : IStableDiffusionService /// Initializes a new instance of the class. /// /// The scheduler service. - public StableDiffusionService(StableDiffusionConfig configuration, IOnnxModelService onnxModelService, IEnumerable pipelines) + public StableDiffusionService(StableDiffusionConfig configuration, IOnnxModelService onnxModelService, IVideoService videoService, IEnumerable pipelines) { _configuration = configuration; _modelService = onnxModelService; + _videoService = videoService; _pipelines = pipelines.ToConcurrentDictionary(k => k.PipelineType, k => k); } @@ -115,9 +116,11 @@ public async Task> GenerateAsImageAsync(StableDiffusionModelSet mo /// The diffusion result as public async Task GenerateAsBytesAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) { - return await GenerateAsync(model, prompt, options, progressCallback, cancellationToken) - .ContinueWith(t => t.Result.ToImageBytes(), cancellationToken) - .ConfigureAwait(false); + var generateResult = await GenerateAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false); + if (!prompt.HasInputVideo) + return generateResult.ToImageBytes(); + + return await GetVideoResultAsBytesAsync(options, 
generateResult, cancellationToken).ConfigureAwait(false); } @@ -131,9 +134,11 @@ public async Task GenerateAsBytesAsync(StableDiffusionModelSet model, Pr /// The diffusion result as public async Task GenerateAsStreamAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) { - return await GenerateAsync(model, prompt, options, progressCallback, cancellationToken) - .ContinueWith(t => t.Result.ToImageStream(), cancellationToken) - .ConfigureAwait(false); + var generateResult = await GenerateAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false); + if (!prompt.HasInputVideo) + return generateResult.ToImageStream(); + + return await GetVideoResultAsStreamAsync(options, generateResult, cancellationToken).ConfigureAwait(false); } @@ -183,7 +188,12 @@ public async IAsyncEnumerable> GenerateBatchAsImageAsync(StableDif public async IAsyncEnumerable GenerateBatchAsBytesAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) - yield return result.ImageResult.ToImageBytes(); + { + if (promptOptions.HasInputVideo) + yield return await GetVideoResultAsBytesAsync(schedulerOptions, result.ImageResult, cancellationToken).ConfigureAwait(false); + else + yield return result.ImageResult.ToImageBytes(); + } } @@ -200,7 +210,12 @@ public async IAsyncEnumerable GenerateBatchAsBytesAsync(StableDiffusionM public async IAsyncEnumerable GenerateBatchAsStreamAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, 
[EnumeratorCancellation] CancellationToken cancellationToken = default) { await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) - yield return result.ImageResult.ToImageStream(); + { + if (!promptOptions.HasInputVideo) + yield return result.ImageResult.ToImageStream(); + + yield return await GetVideoResultAsStreamAsync(schedulerOptions, result.ImageResult, cancellationToken).ConfigureAwait(false); + } } @@ -237,6 +252,21 @@ private IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet return diffuser.DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progress, cancellationToken); } + private async Task GetVideoResultAsBytesAsync(SchedulerOptions options, DenseTensor tensorResult, CancellationToken cancellationToken = default) + { + var frameTensors = tensorResult + .Split(tensorResult.Dimensions[0]) + .Select(x => x.ToImageBytes()); + + var videoResult = await _videoService.CreateVideoAsync(frameTensors, options.VideoFPS, cancellationToken); + return videoResult.Data; + } + + private async Task GetVideoResultAsStreamAsync(SchedulerOptions options, DenseTensor tensorResult, CancellationToken cancellationToken = default) + { + return new MemoryStream(await GetVideoResultAsBytesAsync(options, tensorResult, cancellationToken)); + } + } } From 8c00ba9c8a128a966b0e10581b9fde7f47614472 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sun, 31 Dec 2023 15:20:28 +1300 Subject: [PATCH 4/8] LCM VideoToVideo diffuser --- .../LatentConsistency/VideoDiffuser.cs | 170 ++++++++++++++++++ OnnxStack.StableDiffusion/Registration.cs | 4 + 2 files changed, 174 insertions(+) create mode 100644 OnnxStack.StableDiffusion/Diffusers/LatentConsistency/VideoDiffuser.cs diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/VideoDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/VideoDiffuser.cs new file mode 100644 index 00000000..355fd209 
--- /dev/null +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/VideoDiffuser.cs @@ -0,0 +1,170 @@ +using Microsoft.Extensions.Logging; +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core; +using OnnxStack.Core.Config; +using OnnxStack.Core.Model; +using OnnxStack.Core.Services; +using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Helpers; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency +{ + public sealed class VideoDiffuser : LatentConsistencyDiffuser + { + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + /// The onnx model service. + public VideoDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger logger) + : base(onnxModelService, promptService, logger) { } + + + /// + /// Gets the type of the diffuser. + /// + public override DiffuserType DiffuserType => DiffuserType.VideoToVideo; + + + /// + /// Runs the scheduler steps. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The prompt embeddings. + /// if set to true [perform guidance]. + /// The progress callback. + /// The cancellation token. 
+ /// + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + { + DenseTensor resultTensor = null; + foreach (var videoFrame in promptOptions.InputVideo.Frames) + { + // Get Scheduler + using (var scheduler = GetScheduler(schedulerOptions)) + { + // Get timesteps + var timesteps = GetTimesteps(schedulerOptions, scheduler); + + // Create latent sample + var latents = await PrepareFrameLatentsAsync(modelOptions, videoFrame, schedulerOptions, scheduler, timesteps); + + // Get Guidance Scale Embedding + var guidanceEmbeddings = GetGuidanceScaleEmbedding(schedulerOptions.GuidanceScale); + + // Denoised result + DenseTensor denoised = null; + + // Get Model metadata + var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet); + + // Loop though the timesteps + var step = 0; + foreach (var timestep in timesteps) + { + step++; + var stepTime = Stopwatch.GetTimestamp(); + cancellationToken.ThrowIfCancellationRequested(); + + // Create input tensor. 
+ var inputTensor = scheduler.ScaleInput(latents, timestep); + var timestepTensor = CreateTimestepTensor(timestep); + + var outputChannels = 1; + var outputDimension = schedulerOptions.GetScaledDimension(outputChannels); + using (var inferenceParameters = new OnnxInferenceParameters(metadata)) + { + inferenceParameters.AddInputTensor(inputTensor); + inferenceParameters.AddInputTensor(timestepTensor); + inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds); + inferenceParameters.AddInputTensor(guidanceEmbeddings); + inferenceParameters.AddOutputBuffer(outputDimension); + + var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters); + using (var result = results.First()) + { + var noisePred = result.ToDenseTensor(); + + // Scheduler Step + var schedulerResult = scheduler.Step(noisePred, timestep, latents); + + latents = schedulerResult.Result; + denoised = schedulerResult.SampleData; + } + } + + progressCallback?.Invoke(step, timesteps.Count); + _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime); + } + + // Decode Latents + var frameResultTensor = await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, denoised); + resultTensor = resultTensor is null + ? frameResultTensor + : resultTensor.Concatenate(frameResultTensor); + } + } + return resultTensor; + } + + + /// + /// Gets the timesteps. + /// + /// The prompt. + /// The options. + /// The scheduler. + /// + protected override IReadOnlyList GetTimesteps(SchedulerOptions options, IScheduler scheduler) + { + var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps); + var start = Math.Max(options.InferenceSteps - inittimestep, 0); + return scheduler.Timesteps.Skip(start).ToList(); + } + + + /// + /// Prepares the latents for inference. + /// + /// The prompt. + /// The options. + /// The scheduler. 
+ /// + private async Task> PrepareFrameLatentsAsync(StableDiffusionModelSet model, byte[] videoFrame, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + { + var imageTensor = ImageHelpers.TensorFromBytes(videoFrame, new[] { 1, 3, options.Height, options.Width }); + + //TODO: Model Config, Channels + var outputDimension = options.GetScaledDimension(); + var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder); + using (var inferenceParameters = new OnnxInferenceParameters(metadata)) + { + inferenceParameters.AddInputTensor(imageTensor); + inferenceParameters.AddOutputBuffer(outputDimension); + + var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters); + using (var result = results.First()) + { + var outputResult = result.ToDenseTensor(); + var scaledSample = outputResult.MultiplyBy(model.ScaleFactor); + return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps); + } + } + } + + protected override Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) + { + throw new NotImplementedException(); + } + } +} diff --git a/OnnxStack.StableDiffusion/Registration.cs b/OnnxStack.StableDiffusion/Registration.cs index 16d0f94e..17692662 100644 --- a/OnnxStack.StableDiffusion/Registration.cs +++ b/OnnxStack.StableDiffusion/Registration.cs @@ -1,8 +1,10 @@ using Microsoft.Extensions.DependencyInjection; using OnnxStack.Core.Config; +using OnnxStack.Core.Services; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Diffusers; +using OnnxStack.StableDiffusion.Diffusers.LatentConsistency; using OnnxStack.StableDiffusion.Pipelines; using OnnxStack.StableDiffusion.Services; using SixLabors.ImageSharp; @@ -44,6 +46,7 @@ private static void RegisterServices(this IServiceCollection 
serviceCollection) ConfigureLibraries(); // Services + serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); @@ -69,6 +72,7 @@ private static void RegisterServices(this IServiceCollection serviceCollection) serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); //LatentConsistencyXL serviceCollection.AddSingleton(); From f10c73e696e917367491fc845f7721e38707e49a Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sun, 31 Dec 2023 16:31:07 +1300 Subject: [PATCH 5/8] Add current Latent/Image to progress callback --- .../Examples/StableDiffusionBatch.cs | 7 +-- .../Common/IStableDiffusionService.cs | 16 +++---- .../Diffusers/DiffuserBase.cs | 44 +++++++++++++++++-- .../Diffusers/IDiffuser.cs | 4 +- .../Diffusers/InstaFlow/InstaFlowDiffuser.cs | 5 ++- .../InpaintLegacyDiffuser.cs | 5 ++- .../LatentConsistencyDiffuser.cs | 10 ++--- .../LatentConsistency/VideoDiffuser.cs | 16 +++++-- .../InpaintLegacyDiffuser.cs | 5 ++- .../LatentConsistencyXLDiffuser.cs | 4 +- .../StableDiffusion/InpaintDiffuser.cs | 5 ++- .../StableDiffusion/InpaintLegacyDiffuser.cs | 5 ++- .../StableDiffusionDiffuser.cs | 5 ++- .../InpaintLegacyDiffuser.cs | 5 ++- .../StableDiffusionXLDiffuser.cs | 5 ++- .../Models/DiffusionProgress.cs | 16 +++++++ .../Services/StableDiffusionService.cs | 20 ++++----- OnnxStack.UI/Views/ImageInpaintView.xaml.cs | 11 ++--- OnnxStack.UI/Views/ImageToImageView.xaml.cs | 13 +++--- OnnxStack.UI/Views/TextToImageView.xaml.cs | 11 ++--- 20 files changed, 143 insertions(+), 69 deletions(-) create mode 100644 OnnxStack.StableDiffusion/Models/DiffusionProgress.cs diff --git a/OnnxStack.Console/Examples/StableDiffusionBatch.cs b/OnnxStack.Console/Examples/StableDiffusionBatch.cs index a9cd7af2..fecf4b80 100644 --- a/OnnxStack.Console/Examples/StableDiffusionBatch.cs +++ b/OnnxStack.Console/Examples/StableDiffusionBatch.cs @@ -2,6 +2,7 @@ using 
OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; using SixLabors.ImageSharp; namespace OnnxStack.Console.Runner @@ -58,10 +59,10 @@ public async Task RunAsync() await _stableDiffusionService.LoadModelAsync(model); var batchIndex = 0; - var callback = (int batch, int batchCount, int step, int steps) => + var callback = (DiffusionProgress progress) => { - batchIndex = batch; - OutputHelpers.WriteConsole($"Image: {batch}/{batchCount} - Step: {step}/{steps}", ConsoleColor.Cyan); + batchIndex = progress.ProgressValue; + OutputHelpers.WriteConsole($"Image: {progress.ProgressValue}/{progress.ProgressMax} - Step: {progress.SubProgressValue}/{progress.SubProgressMax}", ConsoleColor.Cyan); }; await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, callback)) diff --git a/OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs b/OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs index ce7edcb4..30a10de5 100644 --- a/OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs +++ b/OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs @@ -45,7 +45,7 @@ public interface IStableDiffusionService /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - Task> GenerateAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); + Task> GenerateAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates the StableDiffusion image using the prompt and options provided. @@ -55,7 +55,7 @@ public interface IStableDiffusionService /// The callback used to provide progess of the current InferenceSteps. 
/// The cancellation token. /// The diffusion result as - Task> GenerateAsImageAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); + Task> GenerateAsImageAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates the StableDiffusion image using the prompt and options provided. @@ -65,7 +65,7 @@ public interface IStableDiffusionService /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - Task GenerateAsBytesAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); + Task GenerateAsBytesAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates the StableDiffusion image using the prompt and options provided. @@ -75,7 +75,7 @@ public interface IStableDiffusionService /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - Task GenerateAsStreamAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); + Task GenerateAsStreamAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates a batch of StableDiffusion image using the prompt and options provided. @@ -87,7 +87,7 @@ public interface IStableDiffusionService /// The progress callback. /// The cancellation token. 
/// - IAsyncEnumerable GenerateBatchAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + IAsyncEnumerable GenerateBatchAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates a batch of StableDiffusion image using the prompt and options provided. @@ -99,7 +99,7 @@ public interface IStableDiffusionService /// The progress callback. /// The cancellation token. /// - IAsyncEnumerable> GenerateBatchAsImageAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + IAsyncEnumerable> GenerateBatchAsImageAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates a batch of StableDiffusion image using the prompt and options provided. @@ -111,7 +111,7 @@ public interface IStableDiffusionService /// The progress callback. /// The cancellation token. /// - IAsyncEnumerable GenerateBatchAsBytesAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + IAsyncEnumerable GenerateBatchAsBytesAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); /// /// Generates a batch of StableDiffusion image using the prompt and options provided. 
@@ -123,6 +123,6 @@ public interface IStableDiffusionService /// The progress callback. /// The cancellation token. /// - IAsyncEnumerable GenerateBatchAsStreamAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + IAsyncEnumerable GenerateBatchAsStreamAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); } } \ No newline at end of file diff --git a/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs b/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs index 763076b7..354d2a4c 100644 --- a/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs +++ b/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs @@ -88,7 +88,7 @@ public DiffuserBase(IOnnxModelService onnxModelService, IPromptService promptSer /// The progress callback. /// The cancellation token. /// - protected abstract Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default); + protected abstract Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default); /// @@ -99,7 +99,7 @@ public DiffuserBase(IOnnxModelService onnxModelService, IPromptService promptSer /// The progress. /// The cancellation token. 
/// - public virtual async Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default) + public virtual async Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default) { // Create random seed if none was set schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next(); @@ -133,7 +133,7 @@ public virtual async Task> DiffuseAsync(StableDiffusionModelS /// The cancellation token. /// /// - public virtual async IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public virtual async IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { // Create random seed if none was set schedulerOptions.Seed = schedulerOptions.Seed > 0 ? 
schedulerOptions.Seed : Random.Shared.Next(); @@ -152,7 +152,11 @@ public virtual async IAsyncEnumerable DiffuseBatchAsync(StableDiffu var batchSchedulerOptions = BatchGenerator.GenerateBatch(modelOptions, batchOptions, schedulerOptions); var batchIndex = 1; - var schedulerCallback = (int step, int steps) => progressCallback?.Invoke(batchIndex, batchSchedulerOptions.Count, step, steps); + var schedulerCallback = (DiffusionProgress progress) => progressCallback?.Invoke(new DiffusionProgress(batchIndex, batchSchedulerOptions.Count, progress.ProgressTensor) + { + SubProgressMax = progress.ProgressMax, + SubProgressValue = progress.ProgressValue, + }); foreach (var batchSchedulerOption in batchSchedulerOptions) { var diffuseTime = _logger?.LogBegin("Diffuse starting..."); @@ -251,5 +255,37 @@ protected static IReadOnlyList CreateInputParameters(params Name { return parameters.ToList(); } + + + /// + /// Reports the progress. + /// + /// The progress callback. + /// The progress. + /// The progress maximum. + /// The output. + protected void ReportProgress(Action progressCallback, int progress, int progressMax, DenseTensor output) + { + progressCallback?.Invoke(new DiffusionProgress(progress, progressMax, output)); + } + + + /// + /// Reports the progress. + /// + /// The progress callback. + /// The progress. + /// The progress maximum. + /// The sub progress. + /// The sub progress maximum. + /// The output. 
+ protected void ReportProgress(Action progressCallback, int progress, int progressMax, int subProgress, int subProgressMax, DenseTensor output) + { + progressCallback?.Invoke(new DiffusionProgress(progress, progressMax, output) + { + SubProgressMax = subProgressMax, + SubProgressValue = subProgress, + }); + } } } diff --git a/OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs index 00e60c58..fee6958d 100644 --- a/OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs @@ -33,7 +33,7 @@ public interface IDiffuser /// The progress callback. /// The cancellation token. /// - Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default); /// @@ -46,6 +46,6 @@ public interface IDiffuser /// The progress callback. /// The cancellation token. 
/// - IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); + IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default); } } diff --git a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs index 6759b4e0..f8264670 100644 --- a/OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs @@ -8,6 +8,7 @@ using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; using OnnxStack.StableDiffusion.Schedulers.InstaFlow; using System; using System.Diagnostics; @@ -45,7 +46,7 @@ public InstaFlowDiffuser(IOnnxModelService onnxModelService, IPromptService prom /// The progress callback. /// The cancellation token. 
/// - protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { // Get Scheduler using (var scheduler = GetScheduler(schedulerOptions)) @@ -102,7 +103,7 @@ protected override async Task> SchedulerStepAsync(StableDiffu } } - progressCallback?.Invoke(step, timesteps.Count); + ReportProgress(progressCallback, step, timesteps.Count, latents); _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime); } diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs index b3f92cd9..fa114deb 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs @@ -8,6 +8,7 @@ using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; using SixLabors.ImageSharp; using SixLabors.ImageSharp.Processing; using System; @@ -65,7 +66,7 @@ protected override IReadOnlyList GetTimesteps(SchedulerOptions options, ISc /// The progress callback. /// The cancellation token. 
/// - protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { using (var scheduler = GetScheduler(schedulerOptions)) { @@ -138,7 +139,7 @@ protected override async Task> SchedulerStepAsync(StableDiffu } } - progressCallback?.Invoke(step, timesteps.Count); + ReportProgress(progressCallback, step, timesteps.Count, latents); _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime); } diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs index 1d1a5819..76e64054 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs @@ -46,7 +46,7 @@ public LatentConsistencyDiffuser(IOnnxModelService onnxModelService, IPromptServ /// /// The cancellation token. 
/// - public override Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default) + public override Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default) { // LCM does not support negative prompting promptOptions.NegativePrompt = string.Empty; @@ -64,7 +64,7 @@ public override Task> DiffuseAsync(StableDiffusionModelSet mo /// The progress callback. /// The cancellation token. /// - public override IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default) + public override IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default) { // LCM does not support negative prompting promptOptions.NegativePrompt = string.Empty; @@ -88,7 +88,7 @@ protected override bool ShouldPerformGuidance(SchedulerOptions schedulerOptions) /// The progress callback. /// The cancellation token. 
/// - protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { // Get Scheduler using (var scheduler = GetScheduler(schedulerOptions)) @@ -143,7 +143,7 @@ protected override async Task> SchedulerStepAsync(StableDiffu } } - progressCallback?.Invoke(step, timesteps.Count); + ReportProgress(progressCallback, step, timesteps.Count, latents); _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime); } @@ -186,7 +186,7 @@ protected DenseTensor GetGuidanceScaleEmbedding(float guidance, int embed var embSin = emb.Select(MathF.Sin); var embCos = emb.Select(MathF.Cos); var guidanceEmbedding = embSin.Concat(embCos).ToArray(); - return new DenseTensor(guidanceEmbedding, new[] { 1, embeddingDim }); + return new DenseTensor(guidanceEmbedding, new[] { 1, embeddingDim }); } } } diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/VideoDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/VideoDiffuser.cs index 355fd209..5b810497 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/VideoDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/VideoDiffuser.cs @@ -8,6 +8,7 @@ using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; using System; using System.Collections.Generic; using System.Diagnostics; @@ -45,12 +46,15 @@ public VideoDiffuser(IOnnxModelService onnxModelService, IPromptService 
promptSe /// The progress callback. /// The cancellation token. /// - protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { + var frameIndex = 0; DenseTensor resultTensor = null; - foreach (var videoFrame in promptOptions.InputVideo.Frames) + var videoFrames = promptOptions.InputVideo.Frames; + foreach (var videoFrame in videoFrames) { // Get Scheduler + frameIndex++; using (var scheduler = GetScheduler(schedulerOptions)) { // Get timesteps @@ -103,12 +107,18 @@ protected override async Task> SchedulerStepAsync(StableDiffu } } - progressCallback?.Invoke(step, timesteps.Count); + // Step Progress + ReportProgress(progressCallback, frameIndex, videoFrames.Count, step, timesteps.Count, latents); _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime); } // Decode Latents var frameResultTensor = await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, denoised); + + // Frame Progress + ReportProgress(progressCallback, frameIndex, videoFrames.Count, step, timesteps.Count, frameResultTensor); + + // Concatenate tensor frame resultTensor = resultTensor is null ? 
frameResultTensor : resultTensor.Concatenate(frameResultTensor); diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/InpaintLegacyDiffuser.cs index 6a66f947..ee32bf35 100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/InpaintLegacyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/InpaintLegacyDiffuser.cs @@ -8,6 +8,7 @@ using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; using SixLabors.ImageSharp; using SixLabors.ImageSharp.Processing; using System; @@ -49,7 +50,7 @@ public InpaintLegacyDiffuser(IOnnxModelService onnxModelService, IPromptService /// The progress callback. /// The cancellation token. /// - protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { using (var scheduler = GetScheduler(schedulerOptions)) { @@ -119,7 +120,7 @@ protected override async Task> SchedulerStepAsync(StableDiffu } } - progressCallback?.Invoke(step, timesteps.Count); + ReportProgress(progressCallback, step, timesteps.Count, latents); _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime); } diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs index 8b91e9c6..a4a4df14 
100644 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs @@ -35,7 +35,7 @@ protected LatentConsistencyXLDiffuser(IOnnxModelService onnxModelService, IPromp /// /// The cancellation token. /// - public override Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default) + public override Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progressCallback = null, CancellationToken cancellationToken = default) { // LCM does not support negative prompting promptOptions.NegativePrompt = string.Empty; @@ -53,7 +53,7 @@ public override Task> DiffuseAsync(StableDiffusionModelSet mo /// The progress callback. /// The cancellation token. /// - public override IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default) + public override IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default) { // LCM does not support negative prompting promptOptions.NegativePrompt = string.Empty; diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs index 1e509a51..d3189365 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs @@ -8,6 +8,7 @@ using 
OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; using SixLabors.ImageSharp; using SixLabors.ImageSharp.Processing; using System; @@ -51,7 +52,7 @@ public InpaintDiffuser(IOnnxModelService onnxModelService, IPromptService prompt /// The progress callback. /// The cancellation token. /// - protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { // Get Scheduler using (var scheduler = GetScheduler(schedulerOptions)) @@ -108,7 +109,7 @@ protected override async Task> SchedulerStepAsync(StableDiffu } } - progressCallback?.Invoke(step, timesteps.Count); + ReportProgress(progressCallback, step, timesteps.Count, latents); _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime); } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs index 6fdd5473..922fa666 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs @@ -9,6 +9,7 @@ using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; using SixLabors.ImageSharp; using SixLabors.ImageSharp.Processing; using System; @@ -50,7 +51,7 @@ public 
InpaintLegacyDiffuser(IOnnxModelService onnxModelService, IPromptService /// The progress callback. /// The cancellation token. /// - protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { using (var scheduler = GetScheduler(schedulerOptions)) { @@ -114,7 +115,7 @@ protected override async Task> SchedulerStepAsync(StableDiffu } } - progressCallback?.Invoke(step, timesteps.Count); + ReportProgress(progressCallback, step, timesteps.Count, latents); _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime); } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs index 60f9c913..b1bb5fe5 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs @@ -9,6 +9,7 @@ using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; using OnnxStack.StableDiffusion.Schedulers.StableDiffusion; using System; using System.Collections.Generic; @@ -48,7 +49,7 @@ public StableDiffusionDiffuser(IOnnxModelService onnxModelService, IPromptServic /// The progress callback. /// The cancellation token. 
/// - protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { // Get Scheduler using (var scheduler = GetScheduler(schedulerOptions)) @@ -98,7 +99,7 @@ protected override async Task> SchedulerStepAsync(StableDiffu } } - progressCallback?.Invoke(step, timesteps.Count); + ReportProgress(progressCallback, step, timesteps.Count, latents); _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime); } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs index f8b5f5c7..c373dbe6 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/InpaintLegacyDiffuser.cs @@ -8,6 +8,7 @@ using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; using SixLabors.ImageSharp; using SixLabors.ImageSharp.Processing; using System; @@ -49,7 +50,7 @@ public InpaintLegacyDiffuser(IOnnxModelService onnxModelService, IPromptService /// The progress callback. /// The cancellation token. 
/// - protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { using (var scheduler = GetScheduler(schedulerOptions)) { @@ -119,7 +120,7 @@ protected override async Task> SchedulerStepAsync(StableDiffu } } - progressCallback?.Invoke(step, timesteps.Count); + ReportProgress(progressCallback, step, timesteps.Count, latents); _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime); } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs index 72a3dad5..70421c1a 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs @@ -8,6 +8,7 @@ using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; +using OnnxStack.StableDiffusion.Models; using OnnxStack.StableDiffusion.Schedulers.StableDiffusion; using System; using System.Diagnostics; @@ -45,7 +46,7 @@ public StableDiffusionXLDiffuser(IOnnxModelService onnxModelService, IPromptServ /// The progress callback. /// The cancellation token. 
/// - protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) + protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) { // Get Scheduler using (var scheduler = GetScheduler(schedulerOptions)) @@ -101,7 +102,7 @@ protected override async Task> SchedulerStepAsync(StableDiffu } } - progressCallback?.Invoke(step, timesteps.Count); + ReportProgress(progressCallback, step, timesteps.Count, latents); _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime); } diff --git a/OnnxStack.StableDiffusion/Models/DiffusionProgress.cs b/OnnxStack.StableDiffusion/Models/DiffusionProgress.cs new file mode 100644 index 00000000..57687a34 --- /dev/null +++ b/OnnxStack.StableDiffusion/Models/DiffusionProgress.cs @@ -0,0 +1,16 @@ +using Microsoft.ML.OnnxRuntime.Tensors; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace OnnxStack.StableDiffusion.Models +{ + public record DiffusionProgress(int ProgressValue, int ProgressMax, DenseTensor ProgressTensor) + { + public int SubProgressMax { get; set; } + public int SubProgressValue { get; set; } + } + +} diff --git a/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs b/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs index 91e9b256..ff71e812 100644 --- a/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs +++ b/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs @@ -84,7 +84,7 @@ public bool IsModelLoaded(StableDiffusionModelSet modelSet) /// The 
callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - public async Task> GenerateAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) + public async Task> GenerateAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) { return await DiffuseAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false); } @@ -98,7 +98,7 @@ public async Task> GenerateAsync(StableDiffusionModelSet mode /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - public async Task> GenerateAsImageAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) + public async Task> GenerateAsImageAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) { return await GenerateAsync(model, prompt, options, progressCallback, cancellationToken) .ContinueWith(t => t.Result.ToImage(), cancellationToken) @@ -114,7 +114,7 @@ public async Task> GenerateAsImageAsync(StableDiffusionModelSet mo /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. 
/// The diffusion result as - public async Task GenerateAsBytesAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) + public async Task GenerateAsBytesAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) { var generateResult = await GenerateAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false); if (!prompt.HasInputVideo) @@ -132,7 +132,7 @@ public async Task GenerateAsBytesAsync(StableDiffusionModelSet model, Pr /// The callback used to provide progess of the current InferenceSteps. /// The cancellation token. /// The diffusion result as - public async Task GenerateAsStreamAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) + public async Task GenerateAsStreamAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action progressCallback = null, CancellationToken cancellationToken = default) { var generateResult = await GenerateAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false); if (!prompt.HasInputVideo) @@ -152,7 +152,7 @@ public async Task GenerateAsStreamAsync(StableDiffusionModelSet model, P /// The progress callback. /// The cancellation token. 
/// - public IAsyncEnumerable GenerateBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default) + public IAsyncEnumerable GenerateBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, CancellationToken cancellationToken = default) { return DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken); } @@ -168,7 +168,7 @@ public IAsyncEnumerable GenerateBatchAsync(StableDiffusionModelSet /// The progress callback. /// The cancellation token. /// - public async IAsyncEnumerable> GenerateBatchAsImageAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable> GenerateBatchAsImageAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) yield return result.ImageResult.ToImage(); @@ -185,7 +185,7 @@ public async IAsyncEnumerable> GenerateBatchAsImageAsync(StableDif /// The progress callback. /// The cancellation token. 
/// - public async IAsyncEnumerable GenerateBatchAsBytesAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable GenerateBatchAsBytesAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) { @@ -207,7 +207,7 @@ public async IAsyncEnumerable GenerateBatchAsBytesAsync(StableDiffusionM /// The progress callback. /// The cancellation token. /// - public async IAsyncEnumerable GenerateBatchAsStreamAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable GenerateBatchAsStreamAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken)) { @@ -219,7 +219,7 @@ public async IAsyncEnumerable GenerateBatchAsStreamAsync(StableDiffusion } - private async Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progress = null, CancellationToken cancellationToken = default) + private async Task> DiffuseAsync(StableDiffusionModelSet modelOptions, 
PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progress = null, CancellationToken cancellationToken = default) { if (!_pipelines.TryGetValue(modelOptions.PipelineType, out var pipeline)) throw new Exception("Pipeline not found or is unsupported"); @@ -236,7 +236,7 @@ private async Task> DiffuseAsync(StableDiffusionModelSet mode } - private IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progress = null, CancellationToken cancellationToken = default) + private IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progress = null, CancellationToken cancellationToken = default) { if (!_pipelines.TryGetValue(modelOptions.PipelineType, out var pipeline)) throw new Exception("Pipeline not found or is unsupported"); diff --git a/OnnxStack.UI/Views/ImageInpaintView.xaml.cs b/OnnxStack.UI/Views/ImageInpaintView.xaml.cs index 05142875..251b534d 100644 --- a/OnnxStack.UI/Views/ImageInpaintView.xaml.cs +++ b/OnnxStack.UI/Views/ImageInpaintView.xaml.cs @@ -3,6 +3,7 @@ using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Models; using OnnxStack.UI.Commands; using OnnxStack.UI.Models; using System; @@ -373,19 +374,19 @@ private Task GenerateResultAsync(byte[] imageBytes, PromptOptions p /// StableDiffusion progress callback. 
/// /// - private Action ProgressCallback() + private Action ProgressCallback() { - return (value, maximum) => + return (progress) => { App.UIInvoke(() => { if (_cancelationTokenSource.IsCancellationRequested) return; - if (ProgressMax != maximum) - ProgressMax = maximum; + if (ProgressMax != progress.ProgressMax) + ProgressMax = progress.ProgressMax; - ProgressValue = value; + ProgressValue = progress.ProgressValue; }); }; } diff --git a/OnnxStack.UI/Views/ImageToImageView.xaml.cs b/OnnxStack.UI/Views/ImageToImageView.xaml.cs index fe220cce..cfab7ac7 100644 --- a/OnnxStack.UI/Views/ImageToImageView.xaml.cs +++ b/OnnxStack.UI/Views/ImageToImageView.xaml.cs @@ -3,6 +3,7 @@ using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Models; using OnnxStack.UI.Commands; using OnnxStack.UI.Models; using System; @@ -340,24 +341,24 @@ private Task GenerateResultAsync(byte[] imageBytes, PromptOptions p /// StableDiffusion progress callback. 
/// /// - private Action ProgressCallback() + private Action ProgressCallback() { - return (value, maximum) => + return (progress) => { App.UIInvoke(() => { if (_cancelationTokenSource.IsCancellationRequested) return; - if (ProgressMax != maximum) - ProgressMax = maximum; + if (ProgressMax != progress.ProgressMax) + ProgressMax = progress.ProgressMax; - ProgressValue = value; + ProgressValue = progress.ProgressValue; }); }; } - + #region INotifyPropertyChanged public event PropertyChangedEventHandler PropertyChanged; public void NotifyPropertyChanged([CallerMemberName] string property = "") diff --git a/OnnxStack.UI/Views/TextToImageView.xaml.cs b/OnnxStack.UI/Views/TextToImageView.xaml.cs index 15cb6bbd..fb94673c 100644 --- a/OnnxStack.UI/Views/TextToImageView.xaml.cs +++ b/OnnxStack.UI/Views/TextToImageView.xaml.cs @@ -2,6 +2,7 @@ using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Models; using OnnxStack.UI.Commands; using OnnxStack.UI.Models; using System; @@ -315,19 +316,19 @@ private Task GenerateResultAsync(byte[] imageBytes, PromptOptions p /// StableDiffusion progress callback. 
/// /// - private Action ProgressCallback() + private Action ProgressCallback() { - return (value, maximum) => + return (progress) => { App.UIInvoke(() => { if (_cancelationTokenSource.IsCancellationRequested) return; - if (ProgressMax != maximum) - ProgressMax = maximum; + if (ProgressMax != progress.ProgressMax) + ProgressMax = progress.ProgressMax; - ProgressValue = value; + ProgressValue = progress.ProgressValue; }); }; } From 648df54827308951c25897069eee299043d475e2 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Mon, 1 Jan 2024 18:49:24 +1300 Subject: [PATCH 6/8] Make DirectML the default build --- OnnxStack.Console/OnnxStack.Console.csproj | 13 ++++----- OnnxStack.UI/OnnxStack.UI.csproj | 17 ++++------- OnnxStack.sln | 34 ++++------------------ 3 files changed, 18 insertions(+), 46 deletions(-) diff --git a/OnnxStack.Console/OnnxStack.Console.csproj b/OnnxStack.Console/OnnxStack.Console.csproj index 6adf143b..c5437e91 100644 --- a/OnnxStack.Console/OnnxStack.Console.csproj +++ b/OnnxStack.Console/OnnxStack.Console.csproj @@ -6,7 +6,7 @@ enable disable x64 - Debug;Release;Debug-DirectML;Debug-Cuda;Debug-TensorRT;Release-DirectML;Release-Cuda;Release-TensorRT + Debug;Release;Debug-Cuda;Debug-TensorRT;Release-Cuda;Release-TensorRT @@ -15,17 +15,16 @@ - - - - + + + + - + - diff --git a/OnnxStack.UI/OnnxStack.UI.csproj b/OnnxStack.UI/OnnxStack.UI.csproj index 9ad0eef3..030dd0f7 100644 --- a/OnnxStack.UI/OnnxStack.UI.csproj +++ b/OnnxStack.UI/OnnxStack.UI.csproj @@ -8,11 +8,7 @@ true true x64 - Debug;Release;Debug-DirectML;Debug-Cuda;Debug-TensorRT;Release-DirectML;Release-Cuda;Release-TensorRT - - - - True + Debug;Release;Debug-Cuda;Debug-TensorRT;Release-Cuda;Release-TensorRT @@ -53,17 +49,16 @@ - - - - + + + + - + - diff --git a/OnnxStack.sln b/OnnxStack.sln index 1a23fb3f..beac8caf 100644 --- a/OnnxStack.sln +++ b/OnnxStack.sln @@ -17,11 +17,9 @@ Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU Debug-Cuda|Any CPU 
= Debug-Cuda|Any CPU - Debug-DirectML|Any CPU = Debug-DirectML|Any CPU Debug-TensorRT|Any CPU = Debug-TensorRT|Any CPU Release|Any CPU = Release|Any CPU Release-Cuda|Any CPU = Release-Cuda|Any CPU - Release-DirectML|Any CPU = Release-DirectML|Any CPU Release-TensorRT|Any CPU = Release-TensorRT|Any CPU EndGlobalSection GlobalSection(ProjectConfigurationPlatforms) = postSolution @@ -29,80 +27,60 @@ Global {02404CEB-207F-4D19-894C-11D51394F1D5}.Debug|Any CPU.Build.0 = Debug|Any CPU {02404CEB-207F-4D19-894C-11D51394F1D5}.Debug-Cuda|Any CPU.ActiveCfg = Debug|Any CPU {02404CEB-207F-4D19-894C-11D51394F1D5}.Debug-Cuda|Any CPU.Build.0 = Debug|Any CPU - {02404CEB-207F-4D19-894C-11D51394F1D5}.Debug-DirectML|Any CPU.ActiveCfg = Debug|Any CPU - {02404CEB-207F-4D19-894C-11D51394F1D5}.Debug-DirectML|Any CPU.Build.0 = Debug|Any CPU {02404CEB-207F-4D19-894C-11D51394F1D5}.Debug-TensorRT|Any CPU.ActiveCfg = Debug|Any CPU {02404CEB-207F-4D19-894C-11D51394F1D5}.Debug-TensorRT|Any CPU.Build.0 = Debug|Any CPU {02404CEB-207F-4D19-894C-11D51394F1D5}.Release|Any CPU.ActiveCfg = Release|Any CPU {02404CEB-207F-4D19-894C-11D51394F1D5}.Release|Any CPU.Build.0 = Release|Any CPU {02404CEB-207F-4D19-894C-11D51394F1D5}.Release-Cuda|Any CPU.ActiveCfg = Release|Any CPU {02404CEB-207F-4D19-894C-11D51394F1D5}.Release-Cuda|Any CPU.Build.0 = Release|Any CPU - {02404CEB-207F-4D19-894C-11D51394F1D5}.Release-DirectML|Any CPU.ActiveCfg = Release|Any CPU - {02404CEB-207F-4D19-894C-11D51394F1D5}.Release-DirectML|Any CPU.Build.0 = Release|Any CPU {02404CEB-207F-4D19-894C-11D51394F1D5}.Release-TensorRT|Any CPU.ActiveCfg = Release|Any CPU {02404CEB-207F-4D19-894C-11D51394F1D5}.Release-TensorRT|Any CPU.Build.0 = Release|Any CPU {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Debug|Any CPU.Build.0 = Debug|Any CPU {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Debug-Cuda|Any CPU.ActiveCfg = Debug|Any CPU 
{EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Debug-Cuda|Any CPU.Build.0 = Debug|Any CPU - {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Debug-DirectML|Any CPU.ActiveCfg = Debug|Any CPU - {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Debug-DirectML|Any CPU.Build.0 = Debug|Any CPU {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Debug-TensorRT|Any CPU.ActiveCfg = Debug|Any CPU {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Debug-TensorRT|Any CPU.Build.0 = Debug|Any CPU {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Release|Any CPU.ActiveCfg = Release|Any CPU {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Release|Any CPU.Build.0 = Release|Any CPU {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Release-Cuda|Any CPU.ActiveCfg = Release|Any CPU {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Release-Cuda|Any CPU.Build.0 = Release|Any CPU - {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Release-DirectML|Any CPU.ActiveCfg = Release|Any CPU - {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Release-DirectML|Any CPU.Build.0 = Release|Any CPU {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Release-TensorRT|Any CPU.ActiveCfg = Release|Any CPU {EA1F61D0-490B-42EC-96F5-7DCCAB94457A}.Release-TensorRT|Any CPU.Build.0 = Release|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug|Any CPU.ActiveCfg = Debug-DirectML|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug|Any CPU.Build.0 = Debug-DirectML|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug|Any CPU.Build.0 = Debug|Any CPU {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-Cuda|Any CPU.ActiveCfg = Debug-Cuda|Any CPU {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-Cuda|Any CPU.Build.0 = Debug-Cuda|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-DirectML|Any CPU.ActiveCfg = Debug-DirectML|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-DirectML|Any CPU.Build.0 = Debug-DirectML|Any CPU {46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-TensorRT|Any CPU.ActiveCfg = Debug-TensorRT|Any CPU 
{46A43C80-A440-4461-B7EB-81FA998FB24B}.Debug-TensorRT|Any CPU.Build.0 = Debug-TensorRT|Any CPU {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release|Any CPU.ActiveCfg = Release|Any CPU {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release|Any CPU.Build.0 = Release|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release-Cuda|Any CPU.ActiveCfg = Release|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release-Cuda|Any CPU.Build.0 = Release|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release-DirectML|Any CPU.ActiveCfg = Release|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release-DirectML|Any CPU.Build.0 = Release|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release-TensorRT|Any CPU.ActiveCfg = Release|Any CPU - {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release-TensorRT|Any CPU.Build.0 = Release|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release-Cuda|Any CPU.ActiveCfg = Release-Cuda|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release-Cuda|Any CPU.Build.0 = Release-Cuda|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release-TensorRT|Any CPU.ActiveCfg = Release-TensorRT|Any CPU + {46A43C80-A440-4461-B7EB-81FA998FB24B}.Release-TensorRT|Any CPU.Build.0 = Release-TensorRT|Any CPU {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Debug|Any CPU.Build.0 = Debug|Any CPU {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Debug-Cuda|Any CPU.ActiveCfg = Debug-Cuda|Any CPU {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Debug-Cuda|Any CPU.Build.0 = Debug-Cuda|Any CPU - {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Debug-DirectML|Any CPU.ActiveCfg = Debug-DirectML|Any CPU - {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Debug-DirectML|Any CPU.Build.0 = Debug-DirectML|Any CPU {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Debug-TensorRT|Any CPU.ActiveCfg = Debug-TensorRT|Any CPU {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Debug-TensorRT|Any CPU.Build.0 = Debug-TensorRT|Any CPU {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Release|Any CPU.ActiveCfg = Release|Any 
CPU {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Release|Any CPU.Build.0 = Release|Any CPU {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Release-Cuda|Any CPU.ActiveCfg = Release-Cuda|Any CPU {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Release-Cuda|Any CPU.Build.0 = Release-Cuda|Any CPU - {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Release-DirectML|Any CPU.ActiveCfg = Release-DirectML|Any CPU - {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Release-DirectML|Any CPU.Build.0 = Release-DirectML|Any CPU {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Release-TensorRT|Any CPU.ActiveCfg = Release-TensorRT|Any CPU {85BB1855-8C3B-4049-A2DD-1130FA6CD846}.Release-TensorRT|Any CPU.Build.0 = Release-TensorRT|Any CPU {A33D08BF-7881-4910-8439-5AE46646C1DD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {A33D08BF-7881-4910-8439-5AE46646C1DD}.Debug|Any CPU.Build.0 = Debug|Any CPU {A33D08BF-7881-4910-8439-5AE46646C1DD}.Debug-Cuda|Any CPU.ActiveCfg = Debug|Any CPU {A33D08BF-7881-4910-8439-5AE46646C1DD}.Debug-Cuda|Any CPU.Build.0 = Debug|Any CPU - {A33D08BF-7881-4910-8439-5AE46646C1DD}.Debug-DirectML|Any CPU.ActiveCfg = Debug|Any CPU - {A33D08BF-7881-4910-8439-5AE46646C1DD}.Debug-DirectML|Any CPU.Build.0 = Debug|Any CPU {A33D08BF-7881-4910-8439-5AE46646C1DD}.Debug-TensorRT|Any CPU.ActiveCfg = Debug|Any CPU {A33D08BF-7881-4910-8439-5AE46646C1DD}.Debug-TensorRT|Any CPU.Build.0 = Debug|Any CPU {A33D08BF-7881-4910-8439-5AE46646C1DD}.Release|Any CPU.ActiveCfg = Release|Any CPU {A33D08BF-7881-4910-8439-5AE46646C1DD}.Release|Any CPU.Build.0 = Release|Any CPU {A33D08BF-7881-4910-8439-5AE46646C1DD}.Release-Cuda|Any CPU.ActiveCfg = Release|Any CPU {A33D08BF-7881-4910-8439-5AE46646C1DD}.Release-Cuda|Any CPU.Build.0 = Release|Any CPU - {A33D08BF-7881-4910-8439-5AE46646C1DD}.Release-DirectML|Any CPU.ActiveCfg = Release|Any CPU - {A33D08BF-7881-4910-8439-5AE46646C1DD}.Release-DirectML|Any CPU.Build.0 = Release|Any CPU {A33D08BF-7881-4910-8439-5AE46646C1DD}.Release-TensorRT|Any CPU.ActiveCfg = Release|Any CPU 
{A33D08BF-7881-4910-8439-5AE46646C1DD}.Release-TensorRT|Any CPU.Build.0 = Release|Any CPU EndGlobalSection From 0ffc9ea132474db29c0123beadcc2c747fad9ba8 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Mon, 1 Jan 2024 19:41:44 +1300 Subject: [PATCH 7/8] VideoToVideo process prototype --- .../Examples/StableDiffusionBatch.cs | 9 +- OnnxStack.Core/Image/Extensions.cs | 73 ++++ OnnxStack.Core/Services/IVideoService.cs | 11 +- OnnxStack.Core/Services/VideoService.cs | 18 +- OnnxStack.Core/Video/Extensions.cs | 49 +++ OnnxStack.Core/Video/VideoInfo.cs | 2 +- OnnxStack.Core/Video/VideoInput.cs | 16 +- .../Config/PromptOptions.cs | 7 +- .../Config/SchedulerOptions.cs | 2 - .../Diffusers/DiffuserBase.cs | 71 +++- .../LatentConsistency/VideoDiffuser.cs | 180 -------- .../Enums/DiffuserType.cs | 5 +- .../Helpers/ImageHelpers.cs | 43 +- .../Helpers/TensorHelper.cs | 3 + .../Models/DiffusionProgress.cs | 11 +- OnnxStack.StableDiffusion/Registration.cs | 2 - .../Services/StableDiffusionService.cs | 99 ++++- OnnxStack.UI/MainWindow.xaml | 15 + OnnxStack.UI/MainWindow.xaml.cs | 3 +- OnnxStack.UI/Models/OnnxStackUIConfig.cs | 6 +- OnnxStack.UI/Models/PromptOptionsModel.cs | 14 + OnnxStack.UI/Models/VideoInputModel.cs | 13 + OnnxStack.UI/OnnxStack.UI.csproj | 9 + OnnxStack.UI/UserControls/PromptControl.xaml | 14 +- .../UserControls/PromptControl.xaml.cs | 15 +- .../UserControls/VideoInputControl.xaml | 78 ++++ .../UserControls/VideoInputControl.xaml.cs | 182 ++++++++ .../UserControls/VideoResultControl.xaml | 88 ++++ .../UserControls/VideoResultControl.xaml.cs | 206 +++++++++ OnnxStack.UI/Views/ImageInpaintView.xaml.cs | 6 +- OnnxStack.UI/Views/ImageToImageView.xaml.cs | 6 +- OnnxStack.UI/Views/TextToImageView.xaml.cs | 6 +- OnnxStack.UI/Views/VideoToVideoView.xaml | 117 +++++ OnnxStack.UI/Views/VideoToVideoView.xaml.cs | 398 ++++++++++++++++++ 34 files changed, 1491 insertions(+), 286 deletions(-) create mode 100644 OnnxStack.Core/Image/Extensions.cs create mode 100644 
OnnxStack.Core/Video/Extensions.cs delete mode 100644 OnnxStack.StableDiffusion/Diffusers/LatentConsistency/VideoDiffuser.cs create mode 100644 OnnxStack.UI/Models/VideoInputModel.cs create mode 100644 OnnxStack.UI/UserControls/VideoInputControl.xaml create mode 100644 OnnxStack.UI/UserControls/VideoInputControl.xaml.cs create mode 100644 OnnxStack.UI/UserControls/VideoResultControl.xaml create mode 100644 OnnxStack.UI/UserControls/VideoResultControl.xaml.cs create mode 100644 OnnxStack.UI/Views/VideoToVideoView.xaml create mode 100644 OnnxStack.UI/Views/VideoToVideoView.xaml.cs diff --git a/OnnxStack.Console/Examples/StableDiffusionBatch.cs b/OnnxStack.Console/Examples/StableDiffusionBatch.cs index fecf4b80..b6cb2cb6 100644 --- a/OnnxStack.Console/Examples/StableDiffusionBatch.cs +++ b/OnnxStack.Console/Examples/StableDiffusionBatch.cs @@ -1,4 +1,5 @@ -using OnnxStack.StableDiffusion.Common; +using OnnxStack.Core.Image; +using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Helpers; @@ -61,11 +62,11 @@ public async Task RunAsync() var batchIndex = 0; var callback = (DiffusionProgress progress) => { - batchIndex = progress.ProgressValue; - OutputHelpers.WriteConsole($"Image: {progress.ProgressValue}/{progress.ProgressMax} - Step: {progress.SubProgressValue}/{progress.SubProgressMax}", ConsoleColor.Cyan); + batchIndex = progress.BatchValue; + OutputHelpers.WriteConsole($"Image: {progress.BatchValue}/{progress.BatchMax} - Step: {progress.StepValue}/{progress.StepMax}", ConsoleColor.Cyan); }; - await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, callback)) + await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, default)) { var outputFilename = Path.Combine(_outputDirectory, $"{batchIndex}_{result.SchedulerOptions.Seed}.png"); var 
image = result.ImageResult.ToImage(); diff --git a/OnnxStack.Core/Image/Extensions.cs b/OnnxStack.Core/Image/Extensions.cs new file mode 100644 index 00000000..f3eadf7e --- /dev/null +++ b/OnnxStack.Core/Image/Extensions.cs @@ -0,0 +1,73 @@ +using Microsoft.ML.OnnxRuntime.Tensors; +using SixLabors.ImageSharp.PixelFormats; +using SixLabors.ImageSharp; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.IO; + +namespace OnnxStack.Core.Image +{ + public static class Extensions + { + public static Image ToImage(this DenseTensor imageTensor) + { + var height = imageTensor.Dimensions[2]; + var width = imageTensor.Dimensions[3]; + var hasAlpha = imageTensor.Dimensions[1] == 4; + var result = new Image(width, height); + for (var y = 0; y < height; y++) + { + for (var x = 0; x < width; x++) + { + result[x, y] = new Rgba32( + CalculateByte(imageTensor, 0, y, x), + CalculateByte(imageTensor, 1, y, x), + CalculateByte(imageTensor, 2, y, x), + hasAlpha ? CalculateByte(imageTensor, 3, y, x) : byte.MaxValue + ); + } + } + return result; + } + + /// + /// Converts to image byte array. + /// + /// The image tensor. + /// + public static byte[] ToImageBytes(this DenseTensor imageTensor) + { + using (var image = imageTensor.ToImage()) + using (var memoryStream = new MemoryStream()) + { + image.SaveAsPng(memoryStream); + return memoryStream.ToArray(); + } + } + + /// + /// Converts to image byte array. + /// + /// The image tensor. 
+ /// + public static async Task ToImageBytesAsync(this DenseTensor imageTensor) + { + using (var image = imageTensor.ToImage()) + using (var memoryStream = new MemoryStream()) + { + await image.SaveAsPngAsync(memoryStream); + return memoryStream.ToArray(); + } + } + + + private static byte CalculateByte(Tensor imageTensor, int index, int y, int x) + { + return (byte)Math.Round(Math.Clamp(imageTensor[0, index, y, x] / 2 + 0.5, 0, 1) * 255); + } + + } +} diff --git a/OnnxStack.Core/Services/IVideoService.cs b/OnnxStack.Core/Services/IVideoService.cs index 465aa501..52be460c 100644 --- a/OnnxStack.Core/Services/IVideoService.cs +++ b/OnnxStack.Core/Services/IVideoService.cs @@ -1,4 +1,5 @@ -using OnnxStack.Core.Video; +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core.Video; using System.Collections.Generic; using System.IO; using System.Threading; @@ -87,6 +88,14 @@ public interface IVideoService /// Task CreateVideoAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default); + // + /// Creates and MP4 video from a collection of PNG images. + /// + /// The video frames. + /// The video FPS. + /// The cancellation token. + /// + Task CreateVideoAsync(DenseTensor videoTensor, float videoFPS, CancellationToken cancellationToken = default); /// /// Streams frames as PNG as they are processed from a video source diff --git a/OnnxStack.Core/Services/VideoService.cs b/OnnxStack.Core/Services/VideoService.cs index 08f7a2eb..6d4d7927 100644 --- a/OnnxStack.Core/Services/VideoService.cs +++ b/OnnxStack.Core/Services/VideoService.cs @@ -1,4 +1,5 @@ using FFMpegCore; +using Microsoft.ML.OnnxRuntime.Tensors; using OnnxStack.Core.Config; using OnnxStack.Core.Video; using System; @@ -94,6 +95,20 @@ public async Task CreateVideoAsync(VideoFrames videoFrames, Cancell } + /// + /// Creates and MP4 video from a collection of PNG images. + /// + /// The video tensor. + /// The video FPS. + /// The cancellation token. 
+ /// + public async Task CreateVideoAsync(DenseTensor videoTensor, float videoFPS, CancellationToken cancellationToken = default) + { + var videoFrames = await videoTensor.ToVideoFramesAsBytesAsync().ToListAsync(cancellationToken); + return await CreateVideoInternalAsync(videoFrames, videoFPS, cancellationToken); + } + + /// /// Creates and MP4 video from a collection of PNG images. /// @@ -141,6 +156,7 @@ public async Task CreateFramesAsync(byte[] videoBytes, float videoF { var videoInfo = await GetVideoInfoAsync(videoBytes, cancellationToken); var videoFrames = await CreateFramesInternalAsync(videoBytes, videoFPS, cancellationToken).ToListAsync(cancellationToken); + videoInfo = videoInfo with { FPS = videoFPS }; return new VideoFrames(videoInfo, videoFrames); } @@ -190,7 +206,7 @@ public IAsyncEnumerable StreamFramesAsync(byte[] videoBytes, float targe /// private async Task GetVideoInfoInternalAsync(MemoryStream videoStream, CancellationToken cancellationToken = default) { - var result = await FFProbe.AnalyseAsync(videoStream, cancellationToken: cancellationToken).ConfigureAwait(false); + var result = await FFProbe.AnalyseAsync(videoStream).ConfigureAwait(false); return new VideoInfo(result.PrimaryVideoStream.Width, result.PrimaryVideoStream.Height, result.Duration, (int)result.PrimaryVideoStream.FrameRate); } diff --git a/OnnxStack.Core/Video/Extensions.cs b/OnnxStack.Core/Video/Extensions.cs new file mode 100644 index 00000000..d32a2c96 --- /dev/null +++ b/OnnxStack.Core/Video/Extensions.cs @@ -0,0 +1,49 @@ +using Microsoft.ML.OnnxRuntime.Tensors; +using OnnxStack.Core.Image; +using SixLabors.ImageSharp; +using SixLabors.ImageSharp.PixelFormats; +using System.Collections.Generic; + +namespace OnnxStack.Core.Video +{ + public static class Extensions + { + public static IEnumerable> ToVideoFrames(this DenseTensor videoTensor) + { + var count = videoTensor.Dimensions[0]; + var dimensions = videoTensor.Dimensions.ToArray(); + dimensions[0] = 1; + + var newLength 
= (int)videoTensor.Length / count; + for (int i = 0; i < count; i++) + { + var start = i * newLength; + yield return new DenseTensor(videoTensor.Buffer.Slice(start, newLength), dimensions); + } + } + + public static IEnumerable ToVideoFramesAsBytes(this DenseTensor videoTensor) + { + foreach (var frame in videoTensor.ToVideoFrames()) + { + yield return frame.ToImageBytes(); + } + } + + public static async IAsyncEnumerable ToVideoFramesAsBytesAsync(this DenseTensor videoTensor) + { + foreach (var frame in videoTensor.ToVideoFrames()) + { + yield return await frame.ToImageBytesAsync(); + } + } + + public static IEnumerable> ToVideoFramesAsImage(this DenseTensor videoTensor) + { + foreach (var frame in videoTensor.ToVideoFrames()) + { + yield return frame.ToImage(); + } + } + } +} diff --git a/OnnxStack.Core/Video/VideoInfo.cs b/OnnxStack.Core/Video/VideoInfo.cs index d3a91f03..17f3fb39 100644 --- a/OnnxStack.Core/Video/VideoInfo.cs +++ b/OnnxStack.Core/Video/VideoInfo.cs @@ -2,5 +2,5 @@ namespace OnnxStack.Core.Video { - public record VideoInfo(int Width, int Height, TimeSpan Duration, int FPS); + public record VideoInfo(int Width, int Height, TimeSpan Duration, float FPS); } diff --git a/OnnxStack.Core/Video/VideoInput.cs b/OnnxStack.Core/Video/VideoInput.cs index ef2b268e..8b1b4506 100644 --- a/OnnxStack.Core/Video/VideoInput.cs +++ b/OnnxStack.Core/Video/VideoInput.cs @@ -29,6 +29,12 @@ public VideoInput() { } /// The video tensor. public VideoInput(DenseTensor videoTensor) => VideoTensor = videoTensor; + /// + /// Initializes a new instance of the class. + /// + /// The video frames. + public VideoInput(VideoFrames videoFrames) => VideoFrames = videoFrames; + /// /// Gets the video bytes. @@ -51,6 +57,13 @@ public VideoInput() { } public DenseTensor VideoTensor { get; set; } + /// + /// Gets or sets the video frames. + /// + [JsonIgnore] + public VideoFrames VideoFrames { get; set; } + + /// /// Gets a value indicating whether this instance has video. 
/// @@ -60,6 +73,7 @@ public VideoInput() { } [JsonIgnore] public bool HasVideo => VideoBytes != null || VideoStream != null - || VideoTensor != null; + || VideoTensor != null + || VideoFrames != null; } } diff --git a/OnnxStack.StableDiffusion/Config/PromptOptions.cs b/OnnxStack.StableDiffusion/Config/PromptOptions.cs index 5c0ab60b..d85cc1eb 100644 --- a/OnnxStack.StableDiffusion/Config/PromptOptions.cs +++ b/OnnxStack.StableDiffusion/Config/PromptOptions.cs @@ -20,9 +20,12 @@ public class PromptOptions public InputImage InputImageMask { get; set; } - public VideoFrames InputVideo { get; set; } + public VideoInput InputVideo { get; set; } - public bool HasInputVideo => InputVideo?.Frames?.Count > 0; + public float VideoInputFPS { get; set; } + public float VideoOutputFPS { get; set; } + + public bool HasInputVideo => InputVideo?.HasVideo ?? false; public bool HasInputImage => InputImage?.HasImage ?? false; public bool HasInputImageMask => InputImageMask?.HasImage ?? false; } diff --git a/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs b/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs index 50e33555..3b258a0c 100644 --- a/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs +++ b/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs @@ -84,8 +84,6 @@ public record SchedulerOptions public float AestheticScore { get; set; } = 6f; public float AestheticNegativeScore { get; set; } = 2.5f; - public float VideoFPS { get; set; } - public bool IsKarrasScheduler { get diff --git a/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs b/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs index 354d2a4c..39050ced 100644 --- a/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs +++ b/OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs @@ -3,6 +3,7 @@ using Microsoft.ML.OnnxRuntime.Tensors; using OnnxStack.Core; using OnnxStack.Core.Config; +using OnnxStack.Core.Image; using OnnxStack.Core.Model; using OnnxStack.Core.Services; using OnnxStack.StableDiffusion.Common; @@ 
-113,15 +114,38 @@ public virtual async Task> DiffuseAsync(StableDiffusionModelS // Process prompts var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance); + // If video input, process frames + if (promptOptions.HasInputVideo) + { + var frameIndex = 0; + DenseTensor videoTensor = null; + var videoFrames = promptOptions.InputVideo.VideoFrames.Frames; + var schedulerFrameCallback = CreateBatchCallback(progressCallback, videoFrames.Count, () => frameIndex); + foreach (var videoFrame in videoFrames) + { + frameIndex++; + promptOptions.InputImage = new InputImage(videoFrame); + var frameResultTensor = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerFrameCallback, cancellationToken); + + // Frame Progress + ReportBatchProgress(progressCallback, frameIndex, videoFrames.Count, frameResultTensor); + + // Concatenate frame + videoTensor = videoTensor.Concatenate(frameResultTensor); + } + + _logger?.LogEnd($"Diffuse complete", diffuseTime); + return videoTensor; + } + // Run Scheduler steps var schedulerResult = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken); - _logger?.LogEnd($"Diffuse complete", diffuseTime); - return schedulerResult; } + /// /// Runs the stable diffusion batch loop /// @@ -152,15 +176,11 @@ public virtual async IAsyncEnumerable DiffuseBatchAsync(StableDiffu var batchSchedulerOptions = BatchGenerator.GenerateBatch(modelOptions, batchOptions, schedulerOptions); var batchIndex = 1; - var schedulerCallback = (DiffusionProgress progress) => progressCallback?.Invoke(new DiffusionProgress(batchIndex, batchSchedulerOptions.Count, progress.ProgressTensor) - { - SubProgressMax = progress.ProgressMax, - SubProgressValue = progress.ProgressValue, - }); + var batchSchedulerCallback = CreateBatchCallback(progressCallback, batchSchedulerOptions.Count, () => 
batchIndex); foreach (var batchSchedulerOption in batchSchedulerOptions) { var diffuseTime = _logger?.LogBegin("Diffuse starting..."); - yield return new BatchResult(batchSchedulerOption, await SchedulerStepAsync(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, schedulerCallback, cancellationToken)); + yield return new BatchResult(batchSchedulerOption, await SchedulerStepAsync(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, batchSchedulerCallback, cancellationToken)); _logger?.LogEnd($"Diffuse complete", diffuseTime); batchIndex++; } @@ -264,9 +284,14 @@ protected static IReadOnlyList CreateInputParameters(params Name /// The progress. /// The progress maximum. /// The output. - protected void ReportProgress(Action progressCallback, int progress, int progressMax, DenseTensor output) + protected void ReportProgress(Action progressCallback, int progress, int progressMax, DenseTensor progressTensor) { - progressCallback?.Invoke(new DiffusionProgress(progress, progressMax, output)); + progressCallback?.Invoke(new DiffusionProgress + { + StepMax = progressMax, + StepValue = progress, + StepTensor = progressTensor + }); } @@ -279,13 +304,31 @@ protected void ReportProgress(Action progressCallback, int pr /// The sub progress. /// The sub progress maximum. /// The output. 
- protected void ReportProgress(Action progressCallback, int progress, int progressMax, int subProgress, int subProgressMax, DenseTensor output) + protected void ReportBatchProgress(Action progressCallback, int progress, int progressMax, DenseTensor progressTensor) + { + progressCallback?.Invoke(new DiffusionProgress + { + BatchMax = progressMax, + BatchValue = progress, + BatchTensor = progressTensor + }); + } + + + private static Action CreateBatchCallback(Action progressCallback, int batchCount, Func batchIndex) { - progressCallback?.Invoke(new DiffusionProgress(progress, progressMax, output) + if (progressCallback == null) + return progressCallback; + + return (DiffusionProgress progress) => progressCallback?.Invoke(new DiffusionProgress { - SubProgressMax = subProgressMax, - SubProgressValue = subProgress, + StepMax = progress.StepMax, + StepValue = progress.StepValue, + StepTensor = progress.StepTensor, + BatchMax = batchCount, + BatchValue = batchIndex() }); } + } } diff --git a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/VideoDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/VideoDiffuser.cs deleted file mode 100644 index 5b810497..00000000 --- a/OnnxStack.StableDiffusion/Diffusers/LatentConsistency/VideoDiffuser.cs +++ /dev/null @@ -1,180 +0,0 @@ -using Microsoft.Extensions.Logging; -using Microsoft.ML.OnnxRuntime.Tensors; -using OnnxStack.Core; -using OnnxStack.Core.Config; -using OnnxStack.Core.Model; -using OnnxStack.Core.Services; -using OnnxStack.StableDiffusion.Common; -using OnnxStack.StableDiffusion.Config; -using OnnxStack.StableDiffusion.Enums; -using OnnxStack.StableDiffusion.Helpers; -using OnnxStack.StableDiffusion.Models; -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; - -namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency -{ - public sealed class VideoDiffuser : LatentConsistencyDiffuser - { - /// - 
/// Initializes a new instance of the class. - /// - /// The configuration. - /// The onnx model service. - public VideoDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger logger) - : base(onnxModelService, promptService, logger) { } - - - /// - /// Gets the type of the diffuser. - /// - public override DiffuserType DiffuserType => DiffuserType.VideoToVideo; - - - /// - /// Runs the scheduler steps. - /// - /// The model options. - /// The prompt options. - /// The scheduler options. - /// The prompt embeddings. - /// if set to true [perform guidance]. - /// The progress callback. - /// The cancellation token. - /// - protected override async Task> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action progressCallback = null, CancellationToken cancellationToken = default) - { - var frameIndex = 0; - DenseTensor resultTensor = null; - var videoFrames = promptOptions.InputVideo.Frames; - foreach (var videoFrame in videoFrames) - { - // Get Scheduler - frameIndex++; - using (var scheduler = GetScheduler(schedulerOptions)) - { - // Get timesteps - var timesteps = GetTimesteps(schedulerOptions, scheduler); - - // Create latent sample - var latents = await PrepareFrameLatentsAsync(modelOptions, videoFrame, schedulerOptions, scheduler, timesteps); - - // Get Guidance Scale Embedding - var guidanceEmbeddings = GetGuidanceScaleEmbedding(schedulerOptions.GuidanceScale); - - // Denoised result - DenseTensor denoised = null; - - // Get Model metadata - var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet); - - // Loop though the timesteps - var step = 0; - foreach (var timestep in timesteps) - { - step++; - var stepTime = Stopwatch.GetTimestamp(); - cancellationToken.ThrowIfCancellationRequested(); - - // Create input tensor. 
- var inputTensor = scheduler.ScaleInput(latents, timestep); - var timestepTensor = CreateTimestepTensor(timestep); - - var outputChannels = 1; - var outputDimension = schedulerOptions.GetScaledDimension(outputChannels); - using (var inferenceParameters = new OnnxInferenceParameters(metadata)) - { - inferenceParameters.AddInputTensor(inputTensor); - inferenceParameters.AddInputTensor(timestepTensor); - inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds); - inferenceParameters.AddInputTensor(guidanceEmbeddings); - inferenceParameters.AddOutputBuffer(outputDimension); - - var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters); - using (var result = results.First()) - { - var noisePred = result.ToDenseTensor(); - - // Scheduler Step - var schedulerResult = scheduler.Step(noisePred, timestep, latents); - - latents = schedulerResult.Result; - denoised = schedulerResult.SampleData; - } - } - - // Step Progress - ReportProgress(progressCallback, frameIndex, videoFrames.Count, step, timesteps.Count, latents); - _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime); - } - - // Decode Latents - var frameResultTensor = await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, denoised); - - // Frame Progress - ReportProgress(progressCallback, frameIndex, videoFrames.Count, step, timesteps.Count, frameResultTensor); - - // Concatenate tensor frame - resultTensor = resultTensor is null - ? frameResultTensor - : resultTensor.Concatenate(frameResultTensor); - } - } - return resultTensor; - } - - - /// - /// Gets the timesteps. - /// - /// The prompt. - /// The options. - /// The scheduler. 
- /// - protected override IReadOnlyList GetTimesteps(SchedulerOptions options, IScheduler scheduler) - { - var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps); - var start = Math.Max(options.InferenceSteps - inittimestep, 0); - return scheduler.Timesteps.Skip(start).ToList(); - } - - - /// - /// Prepares the latents for inference. - /// - /// The prompt. - /// The options. - /// The scheduler. - /// - private async Task> PrepareFrameLatentsAsync(StableDiffusionModelSet model, byte[] videoFrame, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) - { - var imageTensor = ImageHelpers.TensorFromBytes(videoFrame, new[] { 1, 3, options.Height, options.Width }); - - //TODO: Model Config, Channels - var outputDimension = options.GetScaledDimension(); - var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder); - using (var inferenceParameters = new OnnxInferenceParameters(metadata)) - { - inferenceParameters.AddInputTensor(imageTensor); - inferenceParameters.AddOutputBuffer(outputDimension); - - var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters); - using (var result = results.First()) - { - var outputResult = result.ToDenseTensor(); - var scaledSample = outputResult.MultiplyBy(model.ScaleFactor); - return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps); - } - } - } - - protected override Task> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList timesteps) - { - throw new NotImplementedException(); - } - } -} diff --git a/OnnxStack.StableDiffusion/Enums/DiffuserType.cs b/OnnxStack.StableDiffusion/Enums/DiffuserType.cs index dd4c997a..96a127f1 100644 --- a/OnnxStack.StableDiffusion/Enums/DiffuserType.cs +++ b/OnnxStack.StableDiffusion/Enums/DiffuserType.cs @@ -17,9 +17,6 @@ public enum 
DiffuserType ImageInpaintLegacy = 3, [Description("Image To Animation")] - ImageToAnimation = 4, - - [Description("Video To Video")] - VideoToVideo = 5 + ImageToAnimation = 4 } } diff --git a/OnnxStack.StableDiffusion/Helpers/ImageHelpers.cs b/OnnxStack.StableDiffusion/Helpers/ImageHelpers.cs index 2fdf6983..2db76098 100644 --- a/OnnxStack.StableDiffusion/Helpers/ImageHelpers.cs +++ b/OnnxStack.StableDiffusion/Helpers/ImageHelpers.cs @@ -6,37 +6,12 @@ using SixLabors.ImageSharp.Processing; using System; using System.IO; +using System.Threading.Tasks; namespace OnnxStack.StableDiffusion.Helpers { public static class ImageHelpers { - /// - /// Converts to image. - /// - /// The image tensor. - /// - public static Image ToImage(this DenseTensor imageTensor) - { - var height = imageTensor.Dimensions[2]; - var width = imageTensor.Dimensions[3]; - var hasAlpha = imageTensor.Dimensions[1] == 4; - var result = new Image(width, height); - for (var y = 0; y < height; y++) - { - for (var x = 0; x < width; x++) - { - result[x, y] = new Rgba32( - CalculateByte(imageTensor, 0, y, x), - CalculateByte(imageTensor, 1, y, x), - CalculateByte(imageTensor, 2, y, x), - hasAlpha ? CalculateByte(imageTensor, 3, y, x) : byte.MaxValue - ); - } - } - return result; - } - /// /// Converts to image. @@ -58,22 +33,6 @@ public static Image ToImage(this InputImage inputImage) } - /// - /// Converts to image byte array. - /// - /// The image tensor. - /// - public static byte[] ToImageBytes(this DenseTensor imageTensor) - { - using (var image = imageTensor.ToImage()) - using (var memoryStream = new MemoryStream()) - { - image.SaveAsPng(memoryStream); - return memoryStream.ToArray(); - } - } - - /// /// Converts to image memory stream. 
/// diff --git a/OnnxStack.StableDiffusion/Helpers/TensorHelper.cs b/OnnxStack.StableDiffusion/Helpers/TensorHelper.cs index 60fbda3c..5e540e3d 100644 --- a/OnnxStack.StableDiffusion/Helpers/TensorHelper.cs +++ b/OnnxStack.StableDiffusion/Helpers/TensorHelper.cs @@ -257,6 +257,9 @@ public static DenseTensor Clip(this DenseTensor tensor, float minV /// Only axis 0 is supported public static DenseTensor Concatenate(this DenseTensor tensor1, DenseTensor tensor2, int axis = 0) { + if (tensor1 == null) + return tensor2.ToDenseTensor(); + if (axis != 0 && axis != 2) throw new NotImplementedException("Only axis 0, 2 is supported"); diff --git a/OnnxStack.StableDiffusion/Models/DiffusionProgress.cs b/OnnxStack.StableDiffusion/Models/DiffusionProgress.cs index 57687a34..9d7ac935 100644 --- a/OnnxStack.StableDiffusion/Models/DiffusionProgress.cs +++ b/OnnxStack.StableDiffusion/Models/DiffusionProgress.cs @@ -7,10 +7,15 @@ namespace OnnxStack.StableDiffusion.Models { - public record DiffusionProgress(int ProgressValue, int ProgressMax, DenseTensor ProgressTensor) + public record DiffusionProgress(string Message = default) { - public int SubProgressMax { get; set; } - public int SubProgressValue { get; set; } + public int BatchMax { get; set; } + public int BatchValue { get; set; } + public DenseTensor BatchTensor { get; set; } + + public int StepMax { get; set; } + public int StepValue { get; set; } + public DenseTensor StepTensor { get; set; } } } diff --git a/OnnxStack.StableDiffusion/Registration.cs b/OnnxStack.StableDiffusion/Registration.cs index 17692662..2ce1a369 100644 --- a/OnnxStack.StableDiffusion/Registration.cs +++ b/OnnxStack.StableDiffusion/Registration.cs @@ -4,7 +4,6 @@ using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Diffusers; -using OnnxStack.StableDiffusion.Diffusers.LatentConsistency; using OnnxStack.StableDiffusion.Pipelines; using OnnxStack.StableDiffusion.Services; using 
SixLabors.ImageSharp; @@ -72,7 +71,6 @@ private static void RegisterServices(this IServiceCollection serviceCollection) serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); - serviceCollection.AddSingleton(); //LatentConsistencyXL serviceCollection.AddSingleton(); diff --git a/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs b/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs index ff71e812..e44d8284 100644 --- a/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs +++ b/OnnxStack.StableDiffusion/Services/StableDiffusionService.cs @@ -1,5 +1,6 @@ using Microsoft.ML.OnnxRuntime.Tensors; using OnnxStack.Core; +using OnnxStack.Core.Image; using OnnxStack.Core.Services; using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; @@ -120,7 +121,7 @@ public async Task GenerateAsBytesAsync(StableDiffusionModelSet model, Pr if (!prompt.HasInputVideo) return generateResult.ToImageBytes(); - return await GetVideoResultAsBytesAsync(options, generateResult, cancellationToken).ConfigureAwait(false); + return await GenerateVideoResultAsBytesAsync(generateResult, prompt.VideoOutputFPS, progressCallback, cancellationToken).ConfigureAwait(false); } @@ -138,7 +139,7 @@ public async Task GenerateAsStreamAsync(StableDiffusionModelSet model, P if (!prompt.HasInputVideo) return generateResult.ToImageStream(); - return await GetVideoResultAsStreamAsync(options, generateResult, cancellationToken).ConfigureAwait(false); + return await GenerateVideoResultAsStreamAsync(generateResult, prompt.VideoOutputFPS, progressCallback, cancellationToken).ConfigureAwait(false); } @@ -192,7 +193,7 @@ public async IAsyncEnumerable GenerateBatchAsBytesAsync(StableDiffusionM if (!promptOptions.HasInputVideo) yield return result.ImageResult.ToImageBytes(); - yield return await GetVideoResultAsBytesAsync(schedulerOptions, result.ImageResult, cancellationToken).ConfigureAwait(false); + yield return await 
GenerateVideoResultAsBytesAsync(result.ImageResult, promptOptions.VideoOutputFPS, progressCallback, cancellationToken).ConfigureAwait(false); } } @@ -214,11 +215,27 @@ public async IAsyncEnumerable GenerateBatchAsStreamAsync(StableDiffusion if (!promptOptions.HasInputVideo) yield return result.ImageResult.ToImageStream(); - yield return await GetVideoResultAsStreamAsync(schedulerOptions, result.ImageResult, cancellationToken).ConfigureAwait(false); + yield return await GenerateVideoResultAsStreamAsync(result.ImageResult, promptOptions.VideoOutputFPS, progressCallback, cancellationToken).ConfigureAwait(false); } } + /// + /// Runs the diffusion process + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The progress. + /// The cancellation token. + /// + /// + /// Pipeline not found or is unsupported + /// or + /// Diffuser not found or is unsupported + /// or + /// Scheduler '{schedulerOptions.SchedulerType}' is not compatible with the `{pipeline.PipelineType}` pipeline. 
+ /// private async Task> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action progress = null, CancellationToken cancellationToken = default) { if (!_pipelines.TryGetValue(modelOptions.PipelineType, out var pipeline)) @@ -232,11 +249,29 @@ private async Task> DiffuseAsync(StableDiffusionModelSet mode if (!schedulerSupported) throw new Exception($"Scheduler '{schedulerOptions.SchedulerType}' is not compatible with the `{pipeline.PipelineType}` pipeline."); + await GenerateInputVideoFrames(promptOptions, progress); return await diffuser.DiffuseAsync(modelOptions, promptOptions, schedulerOptions, progress, cancellationToken); } - private IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progress = null, CancellationToken cancellationToken = default) + /// + /// Runs the batch diffusion process. + /// + /// The model options. + /// The prompt options. + /// The scheduler options. + /// The batch options. + /// The progress. + /// The cancellation token. + /// + /// + /// Pipeline not found or is unsupported + /// or + /// Diffuser not found or is unsupported + /// or + /// Scheduler '{schedulerOptions.SchedulerType}' is not compatible with the `{pipeline.PipelineType}` pipeline. 
+ /// + private async IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action progress = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { if (!_pipelines.TryGetValue(modelOptions.PipelineType, out var pipeline)) throw new Exception("Pipeline not found or is unsupported"); @@ -249,24 +284,60 @@ private IAsyncEnumerable DiffuseBatchAsync(StableDiffusionModelSet if (!schedulerSupported) throw new Exception($"Scheduler '{schedulerOptions.SchedulerType}' is not compatible with the `{pipeline.PipelineType}` pipeline."); - return diffuser.DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progress, cancellationToken); + await GenerateInputVideoFrames(promptOptions, progress); + await foreach (var result in diffuser.DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progress, cancellationToken)) + { + yield return result; + } } - private async Task GetVideoResultAsBytesAsync(SchedulerOptions options, DenseTensor tensorResult, CancellationToken cancellationToken = default) - { - var frameTensors = tensorResult - .Split(tensorResult.Dimensions[0]) - .Select(x => x.ToImageBytes()); - var videoResult = await _videoService.CreateVideoAsync(frameTensors, options.VideoFPS, cancellationToken); + /// + /// Generates the video result as bytes. + /// + /// The options. + /// The video tensor. + /// The progress. + /// The cancellation token. 
+ /// + private async Task GenerateVideoResultAsBytesAsync(DenseTensor videoTensor, float videoFPS, Action progress = null, CancellationToken cancellationToken = default) + { + progress?.Invoke(new DiffusionProgress("Generating Video Result...")); + var videoResult = await _videoService.CreateVideoAsync(videoTensor, videoFPS, cancellationToken); return videoResult.Data; } - private async Task GetVideoResultAsStreamAsync(SchedulerOptions options, DenseTensor tensorResult, CancellationToken cancellationToken = default) + + /// + /// Generates the video result as stream. + /// + /// The options. + /// The video tensor. + /// The progress. + /// The cancellation token. + /// + private async Task GenerateVideoResultAsStreamAsync(DenseTensor videoTensor, float videoFPS, Action progress = null, CancellationToken cancellationToken = default) { - return new MemoryStream(await GetVideoResultAsBytesAsync(options, tensorResult, cancellationToken)); + return new MemoryStream(await GenerateVideoResultAsBytesAsync(videoTensor, videoFPS, progress, cancellationToken)); } + /// + /// Generates the input video frames. + /// + /// The prompt options. + /// The progress. 
+ private async Task GenerateInputVideoFrames(PromptOptions promptOptions, Action progress) + { + if (!promptOptions.HasInputVideo || promptOptions.InputVideo.VideoFrames is not null) + return; + + // Already has VideoFrames + if (promptOptions.InputVideo.VideoFrames is not null) + return; + + progress?.Invoke(new DiffusionProgress($"Generating video frames @ {promptOptions.VideoInputFPS}fps")); + promptOptions.InputVideo.VideoFrames = await _videoService.CreateFramesAsync(promptOptions.InputVideo, promptOptions.VideoInputFPS); + } } } diff --git a/OnnxStack.UI/MainWindow.xaml b/OnnxStack.UI/MainWindow.xaml index 2795b78a..913b522b 100644 --- a/OnnxStack.UI/MainWindow.xaml +++ b/OnnxStack.UI/MainWindow.xaml @@ -67,6 +67,21 @@ + + + + + + + + + + + + + + + diff --git a/OnnxStack.UI/MainWindow.xaml.cs b/OnnxStack.UI/MainWindow.xaml.cs index ad0d2034..94f9d6ea 100644 --- a/OnnxStack.UI/MainWindow.xaml.cs +++ b/OnnxStack.UI/MainWindow.xaml.cs @@ -107,7 +107,8 @@ private enum TabId TextToImage = 0, ImageToImage = 1, ImageInpaint = 2, - Upscaler = 3 + VideoToVideo = 3, + Upscaler = 4 } diff --git a/OnnxStack.UI/Models/OnnxStackUIConfig.cs b/OnnxStack.UI/Models/OnnxStackUIConfig.cs index 24aedcc8..49fac02a 100644 --- a/OnnxStack.UI/Models/OnnxStackUIConfig.cs +++ b/OnnxStack.UI/Models/OnnxStackUIConfig.cs @@ -21,12 +21,12 @@ public class OnnxStackUIConfig : IConfigSection public IEnumerable GetSupportedExecutionProviders() { -#if DEBUG_DIRECTML || RELEASE_DIRECTML - yield return ExecutionProvider.DirectML; -#elif DEBUG_CUDA || RELEASE_CUDA +#if DEBUG_CUDA || RELEASE_CUDA yield return ExecutionProvider.Cuda; #elif DEBUG_TENSORRT || RELEASE_TENSORRT yield return ExecutionProvider.TensorRT; +#else + yield return ExecutionProvider.DirectML; #endif yield return ExecutionProvider.Cpu; } diff --git a/OnnxStack.UI/Models/PromptOptionsModel.cs b/OnnxStack.UI/Models/PromptOptionsModel.cs index 21867bef..d3e030a8 100644 --- a/OnnxStack.UI/Models/PromptOptionsModel.cs +++ 
b/OnnxStack.UI/Models/PromptOptionsModel.cs @@ -9,6 +9,8 @@ public class PromptOptionsModel : INotifyPropertyChanged private string _prompt; private string _negativePrompt; private bool _hasChanged; + private float _videoInputFPS; + private float _videoOutputFPS; [Required] [StringLength(512, MinimumLength = 1)] @@ -31,6 +33,18 @@ public bool HasChanged set { _hasChanged = value; NotifyPropertyChanged(); } } + public float VideoInputFPS + { + get { return _videoInputFPS; } + set { _videoInputFPS = value; NotifyPropertyChanged(); } + } + + public float VideoOutputFPS + { + get { return _videoOutputFPS; } + set { _videoOutputFPS = value; NotifyPropertyChanged(); } + } + #region INotifyPropertyChanged public event PropertyChangedEventHandler PropertyChanged; diff --git a/OnnxStack.UI/Models/VideoInputModel.cs b/OnnxStack.UI/Models/VideoInputModel.cs new file mode 100644 index 00000000..20e06a22 --- /dev/null +++ b/OnnxStack.UI/Models/VideoInputModel.cs @@ -0,0 +1,13 @@ +using OnnxStack.Core.Video; +using System.Text.Json.Serialization; + +namespace OnnxStack.UI.Models +{ + public class VideoInputModel + { + public VideoInfo VideoInfo { get; set; } + public string FileName { get; set; } + public byte[] VideoBytes { get; set; } + } + +} \ No newline at end of file diff --git a/OnnxStack.UI/OnnxStack.UI.csproj b/OnnxStack.UI/OnnxStack.UI.csproj index 030dd0f7..86fe4918 100644 --- a/OnnxStack.UI/OnnxStack.UI.csproj +++ b/OnnxStack.UI/OnnxStack.UI.csproj @@ -67,4 +67,13 @@ + + + PreserveNewest + + + PreserveNewest + + + diff --git a/OnnxStack.UI/UserControls/PromptControl.xaml b/OnnxStack.UI/UserControls/PromptControl.xaml index 04f8ace6..2ee20b78 100644 --- a/OnnxStack.UI/UserControls/PromptControl.xaml +++ b/OnnxStack.UI/UserControls/PromptControl.xaml @@ -13,7 +13,18 @@ - + + + + + + + + + + + + @@ -23,5 +34,6 @@ + diff --git a/OnnxStack.UI/UserControls/PromptControl.xaml.cs b/OnnxStack.UI/UserControls/PromptControl.xaml.cs index 99c6b4d4..e2f7a520 100644 --- 
a/OnnxStack.UI/UserControls/PromptControl.xaml.cs +++ b/OnnxStack.UI/UserControls/PromptControl.xaml.cs @@ -1,4 +1,5 @@ -using OnnxStack.UI.Commands; +using OnnxStack.StableDiffusion.Enums; +using OnnxStack.UI.Commands; using OnnxStack.UI.Models; using System.ComponentModel; using System.Runtime.CompilerServices; @@ -42,6 +43,18 @@ public PromptOptionsModel PromptOptions DependencyProperty.Register("PromptOptions", typeof(PromptOptionsModel), typeof(PromptControl)); + public bool IsVideoControlsEnabled + { + get { return (bool)GetValue(IsVideoControlsEnabledProperty); } + set { SetValue(IsVideoControlsEnabledProperty, value); } + } + public static readonly DependencyProperty IsVideoControlsEnabledProperty = + DependencyProperty.Register("IsVideoControlsEnabled", typeof(bool), typeof(PromptControl)); + + + + + public StableDiffusionModelSetViewModel SelectedModel { get { return (StableDiffusionModelSetViewModel)GetValue(SelectedModelProperty); } diff --git a/OnnxStack.UI/UserControls/VideoInputControl.xaml b/OnnxStack.UI/UserControls/VideoInputControl.xaml new file mode 100644 index 00000000..5ac1b3ef --- /dev/null +++ b/OnnxStack.UI/UserControls/VideoInputControl.xaml @@ -0,0 +1,78 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/OnnxStack.UI/UserControls/VideoInputControl.xaml.cs b/OnnxStack.UI/UserControls/VideoInputControl.xaml.cs new file mode 100644 index 00000000..3b2e77f5 --- /dev/null +++ b/OnnxStack.UI/UserControls/VideoInputControl.xaml.cs @@ -0,0 +1,182 @@ +using Microsoft.Win32; +using OnnxStack.Core.Services; +using OnnxStack.UI.Commands; +using OnnxStack.UI.Models; +using System; +using System.ComponentModel; +using System.IO; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; +using System.Windows; +using System.Windows.Controls; +using System.Windows.Media.Imaging; + +namespace OnnxStack.UI.UserControls +{ + public partial class VideoInputControl : 
UserControl, INotifyPropertyChanged + { + private readonly IVideoService _videoService; + private bool _isPlaying = false; + + /// + /// Initializes a new instance of the class. + /// + public VideoInputControl() + { + if (!DesignerProperties.GetIsInDesignMode(this)) + _videoService = App.GetService(); + + LoadVideoCommand = new AsyncRelayCommand(LoadVideo); + ClearVideoCommand = new AsyncRelayCommand(ClearVideo); + InitializeComponent(); + } + + public AsyncRelayCommand LoadVideoCommand { get; } + public AsyncRelayCommand ClearVideoCommand { get; } + + public VideoInputModel VideoResult + { + get { return (VideoInputModel)GetValue(VideoResultProperty); } + set { SetValue(VideoResultProperty, value); } + } + public static readonly DependencyProperty VideoResultProperty = + DependencyProperty.Register("VideoResult", typeof(VideoInputModel), typeof(VideoInputControl)); + + public SchedulerOptionsModel SchedulerOptions + { + get { return (SchedulerOptionsModel)GetValue(SchedulerOptionsProperty); } + set { SetValue(SchedulerOptionsProperty, value); } + } + public static readonly DependencyProperty SchedulerOptionsProperty = + DependencyProperty.Register("SchedulerOptions", typeof(SchedulerOptionsModel), typeof(VideoInputControl)); + + public PromptOptionsModel PromptOptions + { + get { return (PromptOptionsModel)GetValue(PromptOptionsProperty); } + set { SetValue(PromptOptionsProperty, value); } + } + public static readonly DependencyProperty PromptOptionsProperty = + DependencyProperty.Register("PromptOptions", typeof(PromptOptionsModel), typeof(VideoInputControl)); + + public bool IsGenerating + { + get { return (bool)GetValue(IsGeneratingProperty); } + set { SetValue(IsGeneratingProperty, value); } + } + public static readonly DependencyProperty IsGeneratingProperty = + DependencyProperty.Register("IsGenerating", typeof(bool), typeof(VideoInputControl)); + + public bool HasVideoResult + { + get { return (bool)GetValue(HasVideoResultProperty); } + set { 
SetValue(HasVideoResultProperty, value); } + } + public static readonly DependencyProperty HasVideoResultProperty = + DependencyProperty.Register("HasVideoResult", typeof(bool), typeof(VideoInputControl)); + + public BitmapImage PreviewImage + { + get { return (BitmapImage)GetValue(PreviewImageProperty); } + set { SetValue(PreviewImageProperty, value); } + } + public static readonly DependencyProperty PreviewImageProperty = + DependencyProperty.Register("PreviewImage", typeof(BitmapImage), typeof(VideoInputControl)); + + + /// + /// Loads the image. + /// + /// + private async Task LoadVideo() + { + // Show Dialog + var openFileDialog = new OpenFileDialog + { + Filter = "Video Files|*.mp4;*.avi;*.mkv;*.mov;*.wmv;*.flv;*.gif|All Files|*.*", + RestoreDirectory = true, + Multiselect = false, + }; + if (openFileDialog.ShowDialog() == true) + { + var videoBytes = await File.ReadAllBytesAsync(openFileDialog.FileName); + var videoInfo = await _videoService.GetVideoInfoAsync(videoBytes); + VideoResult = new VideoInputModel + { + FileName = openFileDialog.FileName, + VideoInfo = videoInfo, + VideoBytes = videoBytes + }; + HasVideoResult = true; + PromptOptions.VideoInputFPS = videoInfo.FPS; + PromptOptions.VideoOutputFPS = videoInfo.FPS; + } + } + + + /// + /// Clears the image. + /// + /// + private Task ClearVideo() + { + VideoResult = null; + HasVideoResult = false; + return Task.CompletedTask; + } + + + /// + /// Handles the Loaded event of the MediaElement control. + /// + /// The source of the event. + /// The instance containing the event data. + private void MediaElement_Loaded(object sender, RoutedEventArgs e) + { + (sender as MediaElement).Play(); + _isPlaying = true; + } + + + /// + /// Handles the MediaEnded event of the MediaElement control. + /// + /// The source of the event. + /// The instance containing the event data. 
+ private void MediaElement_MediaEnded(object sender, RoutedEventArgs e) + { + (sender as MediaElement).Position = TimeSpan.FromMilliseconds(1); + } + + + /// + /// Handles the MouseDown event of the MediaElement control. + /// + /// The source of the event. + /// The instance containing the event data. + private void MediaElement_MouseDown(object sender, System.Windows.Input.MouseButtonEventArgs e) + { + if (sender is not MediaElement mediaElement) + return; + + if (_isPlaying) + { + _isPlaying = false; + mediaElement.Pause(); + return; + } + + mediaElement.Play(); + _isPlaying = true; + } + + + #region INotifyPropertyChanged + public event PropertyChangedEventHandler PropertyChanged; + public void NotifyPropertyChanged([CallerMemberName] string property = "") + { + PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(property)); + } + + #endregion + } +} diff --git a/OnnxStack.UI/UserControls/VideoResultControl.xaml b/OnnxStack.UI/UserControls/VideoResultControl.xaml new file mode 100644 index 00000000..ae543364 --- /dev/null +++ b/OnnxStack.UI/UserControls/VideoResultControl.xaml @@ -0,0 +1,88 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/OnnxStack.UI/UserControls/VideoResultControl.xaml.cs b/OnnxStack.UI/UserControls/VideoResultControl.xaml.cs new file mode 100644 index 00000000..7b9568e2 --- /dev/null +++ b/OnnxStack.UI/UserControls/VideoResultControl.xaml.cs @@ -0,0 +1,206 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Win32; +using OnnxStack.StableDiffusion.Config; +using OnnxStack.UI.Commands; +using OnnxStack.UI.Models; +using System; +using System.ComponentModel; +using System.IO; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; +using System.Windows; +using System.Windows.Controls; +using System.Windows.Media.Imaging; + +namespace OnnxStack.UI.UserControls +{ + public partial class VideoResultControl : UserControl, 
INotifyPropertyChanged + { + private readonly ILogger _logger; + private bool _isPlaying = false; + + /// + /// Initializes a new instance of the class. + /// + public VideoResultControl() + { + if (!DesignerProperties.GetIsInDesignMode(this)) + _logger = App.GetService>(); + + SaveVideoCommand = new AsyncRelayCommand(SaveVideo); + ClearVideoCommand = new AsyncRelayCommand(ClearVideo); + InitializeComponent(); + HasVideoResult = false; + } + + public AsyncRelayCommand SaveVideoCommand { get; } + public AsyncRelayCommand ClearVideoCommand { get; } + + public VideoInputModel VideoResult + { + get { return (VideoInputModel)GetValue(VideoResultProperty); } + set { SetValue(VideoResultProperty, value); } + } + public static readonly DependencyProperty VideoResultProperty = + DependencyProperty.Register("VideoResult", typeof(VideoInputModel), typeof(VideoResultControl)); + + public SchedulerOptionsModel SchedulerOptions + { + get { return (SchedulerOptionsModel)GetValue(SchedulerOptionsProperty); } + set { SetValue(SchedulerOptionsProperty, value); } + } + public static readonly DependencyProperty SchedulerOptionsProperty = + DependencyProperty.Register("SchedulerOptions", typeof(SchedulerOptionsModel), typeof(VideoResultControl)); + + public bool HasVideoResult + { + get { return (bool)GetValue(HasVideoResultProperty); } + set { SetValue(HasVideoResultProperty, value); } + } + public static readonly DependencyProperty HasVideoResultProperty = + DependencyProperty.Register("HasVideoResult", typeof(bool), typeof(VideoResultControl)); + + public int ProgressMax + { + get { return (int)GetValue(ProgressMaxProperty); } + set { SetValue(ProgressMaxProperty, value); } + } + public static readonly DependencyProperty ProgressMaxProperty = + DependencyProperty.Register("ProgressMax", typeof(int), typeof(VideoResultControl)); + + public int ProgressValue + { + get { return (int)GetValue(ProgressValueProperty); } + set { SetValue(ProgressValueProperty, value); } + } + public static 
readonly DependencyProperty ProgressValueProperty = + DependencyProperty.Register("ProgressValue", typeof(int), typeof(VideoResultControl)); + + public string ProgressText + { + get { return (string)GetValue(ProgressTextProperty); } + set { SetValue(ProgressTextProperty, value); } + } + public static readonly DependencyProperty ProgressTextProperty = + DependencyProperty.Register("ProgressText", typeof(string), typeof(VideoResultControl)); + + public BitmapImage PreviewImage + { + get { return (BitmapImage)GetValue(PreviewImageProperty); } + set { SetValue(PreviewImageProperty, value); } + } + public static readonly DependencyProperty PreviewImageProperty = + DependencyProperty.Register("PreviewImage", typeof(BitmapImage), typeof(VideoResultControl)); + + + + /// + /// Saves the video. + /// + private async Task SaveVideo() + { + await SaveVideoFile(VideoResult, SchedulerOptions); + } + + /// + /// Clears the image. + /// + /// + private Task ClearVideo() + { + ProgressMax = 1; + VideoResult = null; + HasVideoResult = false; + return Task.CompletedTask; + } + + + /// + /// Saves the video file. + /// + /// The video result. + /// The scheduler options. + private async Task SaveVideoFile(VideoInputModel videoResult, SchedulerOptionsModel schedulerOptions) + { + try + { + var saveFileDialog = new SaveFileDialog + { + Filter = "mp4 files (*.mp4)|*.mp4", + DefaultExt = "mp4", + AddExtension = true, + RestoreDirectory = true, + InitialDirectory = Environment.GetFolderPath(Environment.SpecialFolder.MyPictures), + FileName = $"video-{schedulerOptions.Seed}.mp4" + }; + + var dialogResult = saveFileDialog.ShowDialog(); + if (dialogResult == false) + { + _logger.LogInformation("Saving video canceled"); + return; + } + + // Write File + await File.WriteAllBytesAsync(saveFileDialog.FileName, videoResult.VideoBytes); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error saving video"); + } + } + + + /// + /// Handles the Loaded event of the MediaElement control. 
+ /// + /// The source of the event. + /// The instance containing the event data. + private void MediaElement_Loaded(object sender, RoutedEventArgs e) + { + (sender as MediaElement).Play(); + _isPlaying = true; + } + + + /// + /// Handles the MediaEnded event of the MediaElement control. + /// + /// The source of the event. + /// The instance containing the event data. + private void MediaElement_MediaEnded(object sender, RoutedEventArgs e) + { + (sender as MediaElement).Position = TimeSpan.FromMilliseconds(1); + } + + + /// + /// Handles the MouseDown event of the MediaElement control. + /// + /// The source of the event. + /// The instance containing the event data. + private void MediaElement_MouseDown(object sender, System.Windows.Input.MouseButtonEventArgs e) + { + if (sender is not MediaElement mediaElement) + return; + + if (_isPlaying) + { + _isPlaying = false; + mediaElement.Pause(); + return; + } + + mediaElement.Play(); + _isPlaying = true; + } + + #region INotifyPropertyChanged + public event PropertyChangedEventHandler PropertyChanged; + public void NotifyPropertyChanged([CallerMemberName] string property = "") + { + PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(property)); + } + #endregion + } +} diff --git a/OnnxStack.UI/Views/ImageInpaintView.xaml.cs b/OnnxStack.UI/Views/ImageInpaintView.xaml.cs index 251b534d..29af6d80 100644 --- a/OnnxStack.UI/Views/ImageInpaintView.xaml.cs +++ b/OnnxStack.UI/Views/ImageInpaintView.xaml.cs @@ -383,10 +383,10 @@ private Action ProgressCallback() if (_cancelationTokenSource.IsCancellationRequested) return; - if (ProgressMax != progress.ProgressMax) - ProgressMax = progress.ProgressMax; + if (ProgressMax != progress.StepMax) + ProgressMax = progress.StepMax; - ProgressValue = progress.ProgressValue; + ProgressValue = progress.StepValue; }); }; } diff --git a/OnnxStack.UI/Views/ImageToImageView.xaml.cs b/OnnxStack.UI/Views/ImageToImageView.xaml.cs index cfab7ac7..6b93e308 100644 --- 
a/OnnxStack.UI/Views/ImageToImageView.xaml.cs +++ b/OnnxStack.UI/Views/ImageToImageView.xaml.cs @@ -350,10 +350,10 @@ private Action ProgressCallback() if (_cancelationTokenSource.IsCancellationRequested) return; - if (ProgressMax != progress.ProgressMax) - ProgressMax = progress.ProgressMax; + if (ProgressMax != progress.StepMax) + ProgressMax = progress.StepMax; - ProgressValue = progress.ProgressValue; + ProgressValue = progress.StepValue; }); }; } diff --git a/OnnxStack.UI/Views/TextToImageView.xaml.cs b/OnnxStack.UI/Views/TextToImageView.xaml.cs index fb94673c..17b2b4ff 100644 --- a/OnnxStack.UI/Views/TextToImageView.xaml.cs +++ b/OnnxStack.UI/Views/TextToImageView.xaml.cs @@ -325,10 +325,10 @@ private Action ProgressCallback() if (_cancelationTokenSource.IsCancellationRequested) return; - if (ProgressMax != progress.ProgressMax) - ProgressMax = progress.ProgressMax; + if (ProgressMax != progress.StepMax) + ProgressMax = progress.StepMax; - ProgressValue = progress.ProgressValue; + ProgressValue = progress.StepValue; }); }; } diff --git a/OnnxStack.UI/Views/VideoToVideoView.xaml b/OnnxStack.UI/Views/VideoToVideoView.xaml new file mode 100644 index 00000000..6e7ab4af --- /dev/null +++ b/OnnxStack.UI/Views/VideoToVideoView.xaml @@ -0,0 +1,117 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/OnnxStack.UI/Views/VideoToVideoView.xaml.cs b/OnnxStack.UI/Views/VideoToVideoView.xaml.cs new file mode 100644 index 00000000..fedd2f71 --- /dev/null +++ b/OnnxStack.UI/Views/VideoToVideoView.xaml.cs @@ -0,0 +1,398 @@ +using Microsoft.Extensions.Logging; +using OnnxStack.Core.Image; +using OnnxStack.Core.Services; +using OnnxStack.Core.Video; +using OnnxStack.StableDiffusion.Common; +using OnnxStack.StableDiffusion.Config; +using OnnxStack.StableDiffusion.Enums; +using OnnxStack.StableDiffusion.Models; +using OnnxStack.UI.Commands; +using OnnxStack.UI.Models; +using System; +using System.Collections.Generic; +using 
System.Collections.ObjectModel; +using System.ComponentModel; +using System.Diagnostics; +using System.IO; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using System.Windows; +using System.Windows.Controls; +using System.Windows.Media.Imaging; +using System.Windows.Threading; + +namespace OnnxStack.UI.Views +{ + /// + /// Interaction logic for VideoToVideoView.xaml + /// + public partial class VideoToVideoView : UserControl, INavigatable, INotifyPropertyChanged + { + private readonly ILogger _logger; + private readonly IStableDiffusionService _stableDiffusionService; + private readonly IVideoService _videoService; + + private bool _hasResult; + private int _progressMax; + private int _progressValue; + private string _progressText; + private bool _isGenerating; + private int _selectedTabIndex; + private bool _hasInputResult; + private bool _isControlsEnabled; + private VideoInputModel _inputVideo; + private VideoInputModel _resultVideo; + private StableDiffusionModelSetViewModel _selectedModel; + private PromptOptionsModel _promptOptionsModel; + private SchedulerOptionsModel _schedulerOptions; + private CancellationTokenSource _cancelationTokenSource; + private VideoFrames _videoFrames; + private BitmapImage _previewSource; + private BitmapImage _previewResult; + + /// + /// Initializes a new instance of the class. 
+ /// + public VideoToVideoView() + { + if (!DesignerProperties.GetIsInDesignMode(this)) + { + _logger = App.GetService>(); + _videoService = App.GetService(); + _stableDiffusionService = App.GetService(); + } + + SupportedDiffusers = new() { DiffuserType.ImageToImage }; + CancelCommand = new AsyncRelayCommand(Cancel, CanExecuteCancel); + GenerateCommand = new AsyncRelayCommand(Generate, CanExecuteGenerate); + ClearHistoryCommand = new AsyncRelayCommand(ClearHistory, CanExecuteClearHistory); + PromptOptions = new PromptOptionsModel(); + SchedulerOptions = new SchedulerOptionsModel(); + ImageResults = new ObservableCollection(); + ProgressMax = SchedulerOptions.InferenceSteps; + IsControlsEnabled = true; + InitializeComponent(); + } + + public OnnxStackUIConfig UISettings + { + get { return (OnnxStackUIConfig)GetValue(UISettingsProperty); } + set { SetValue(UISettingsProperty, value); } + } + public static readonly DependencyProperty UISettingsProperty = + DependencyProperty.Register("UISettings", typeof(OnnxStackUIConfig), typeof(VideoToVideoView)); + + public List SupportedDiffusers { get; } + public AsyncRelayCommand CancelCommand { get; } + public AsyncRelayCommand GenerateCommand { get; } + public AsyncRelayCommand ClearHistoryCommand { get; set; } + public ObservableCollection ImageResults { get; } + + public StableDiffusionModelSetViewModel SelectedModel + { + get { return _selectedModel; } + set { _selectedModel = value; NotifyPropertyChanged(); } + } + + public PromptOptionsModel PromptOptions + { + get { return _promptOptionsModel; } + set { _promptOptionsModel = value; NotifyPropertyChanged(); } + } + + public SchedulerOptionsModel SchedulerOptions + { + get { return _schedulerOptions; } + set { _schedulerOptions = value; NotifyPropertyChanged(); } + } + + public VideoInputModel ResultVideo + { + get { return _resultVideo; } + set { _resultVideo = value; NotifyPropertyChanged(); } + } + + public VideoInputModel InputVideo + { + get { return _inputVideo; } 
+ set { _inputVideo = value; _videoFrames = null; NotifyPropertyChanged(); } + } + + public int ProgressValue + { + get { return _progressValue; } + set { _progressValue = value; NotifyPropertyChanged(); } + } + + public int ProgressMax + { + get { return _progressMax; } + set { _progressMax = value; NotifyPropertyChanged(); } + } + + public string ProgressText + { + get { return _progressText; } + set { _progressText = value; NotifyPropertyChanged(); } + } + + public bool IsGenerating + { + get { return _isGenerating; } + set { _isGenerating = value; NotifyPropertyChanged(); } + } + + public bool HasResult + { + get { return _hasResult; } + set { _hasResult = value; NotifyPropertyChanged(); } + } + + public bool HasInputResult + { + get { return _hasInputResult; } + set { _hasInputResult = value; NotifyPropertyChanged(); } + } + + public int SelectedTabIndex + { + get { return _selectedTabIndex; } + set { _selectedTabIndex = value; NotifyPropertyChanged(); } + } + + public bool IsControlsEnabled + { + get { return _isControlsEnabled; } + set { _isControlsEnabled = value; NotifyPropertyChanged(); } + } + + public BitmapImage PreviewSource + { + get { return _previewSource; } + set { _previewSource = value; NotifyPropertyChanged(); } + } + + public BitmapImage PreviewResult + { + get { return _previewResult; } + set { _previewResult = value; NotifyPropertyChanged(); } + } + + + /// + /// Called on Navigate + /// + /// The image result. + /// + public async Task NavigateAsync(ImageResult imageResult) + { + throw new NotImplementedException(); + } + + + + /// + /// Generates this image result. 
+ /// + private async Task Generate() + { + HasResult = false; + IsGenerating = true; + IsControlsEnabled = false; + ResultVideo = null; + ProgressMax = 0; + _cancelationTokenSource = new CancellationTokenSource(); + + try + { + var schedulerOptions = SchedulerOptions.ToSchedulerOptions(); + if (_videoFrames is null || _videoFrames.Info.FPS != PromptOptions.VideoInputFPS) + { + ProgressText = $"Generating video frames @ {PromptOptions.VideoInputFPS}fps"; + _videoFrames = await _videoService.CreateFramesAsync(_inputVideo.VideoBytes, PromptOptions.VideoInputFPS, _cancelationTokenSource.Token); + } + var promptOptions = GetPromptOptions(PromptOptions, _videoFrames); + + var timestamp = Stopwatch.GetTimestamp(); + var result = await _stableDiffusionService.GenerateAsBytesAsync(_selectedModel.ModelSet, promptOptions, schedulerOptions, ProgressCallback(), _cancelationTokenSource.Token); + var resultVideo = await GenerateResultAsync(result, promptOptions, schedulerOptions, timestamp); + if (resultVideo != null) + { + App.UIInvoke(() => + { + ResultVideo = resultVideo; + HasResult = true; + }); + } + } + catch (OperationCanceledException) + { + _logger.LogInformation($"Generate was canceled."); + } + catch (Exception ex) + { + _logger.LogError($"Error during Generate\n{ex}"); + } + + Reset(); + } + + + /// + /// Determines whether this instance can execute Generate. + /// + /// + /// true if this instance can execute Generate; otherwise, false. + /// + private bool CanExecuteGenerate() + { + return !IsGenerating && HasInputResult; + } + + + /// + /// Cancels this generation. + /// + /// + private Task Cancel() + { + _cancelationTokenSource?.Cancel(); + return Task.CompletedTask; + } + + + /// + /// Determines whether this instance can execute Cancel. + /// + /// + /// true if this instance can execute Cancel; otherwise, false. + /// + private bool CanExecuteCancel() + { + return IsGenerating; + } + + + /// + /// Clears the history. 
+ /// + /// + private Task ClearHistory() + { + ImageResults.Clear(); + return Task.CompletedTask; + } + + + /// + /// Determines whether this instance can execute ClearHistory. + /// + /// + /// true if this instance can execute ClearHistory; otherwise, false. + /// + private bool CanExecuteClearHistory() + { + return ImageResults.Count > 0; + } + + + /// + /// Resets this instance. + /// + private void Reset() + { + PreviewSource = null; + PreviewResult = null; + IsGenerating = false; + IsControlsEnabled = true; + ProgressValue = 0; + ProgressMax = 1; + ProgressText = null; + } + + + private PromptOptions GetPromptOptions(PromptOptionsModel promptOptionsModel, VideoFrames videoFrames) + { + return new PromptOptions + { + Prompt = promptOptionsModel.Prompt, + NegativePrompt = promptOptionsModel.NegativePrompt, + DiffuserType = DiffuserType.ImageToImage, + InputVideo = new VideoInput(videoFrames), + VideoInputFPS = promptOptionsModel.VideoInputFPS, + VideoOutputFPS = promptOptionsModel.VideoOutputFPS, + }; + } + + + /// + /// Generates the result. + /// + /// The image bytes. + /// The prompt options. + /// The scheduler options. + /// The timestamp. + /// + private async Task GenerateResultAsync(byte[] videoBytes, PromptOptions promptOptions, SchedulerOptions schedulerOptions, long timestamp) + { + var tempVideoFile = Path.Combine(".temp", $"VideoToVideo.mp4"); + await File.WriteAllBytesAsync(tempVideoFile, videoBytes); + var videoInfo = await _videoService.GetVideoInfoAsync(videoBytes); + var videoResult = new VideoInputModel + { + FileName = tempVideoFile, + VideoInfo = videoInfo, + VideoBytes = videoBytes + }; + return videoResult; + } + + + /// + /// StableDiffusion progress callback. 
+ /// + /// + private Action ProgressCallback() + { + return (progress) => + { + if (_cancelationTokenSource.IsCancellationRequested) + return; + + App.UIInvoke(() => + { + if (_cancelationTokenSource.IsCancellationRequested) + return; + + if (progress.BatchTensor is not null) + { + PreviewResult = Utils.CreateBitmap(progress.BatchTensor.ToImageBytes()); + PreviewSource = Utils.CreateBitmap(_videoFrames.Frames[progress.BatchValue - 1]); + ProgressText = $"Video Frame {progress.BatchValue} of {_videoFrames.Frames.Count} complete"; + } + + if (ProgressText != progress.Message && progress.BatchMax == 0) + ProgressText = progress.Message; + + if (ProgressMax != progress.BatchMax) + ProgressMax = progress.BatchMax; + + if (ProgressValue != progress.BatchValue) + ProgressValue = progress.BatchValue; + + }, DispatcherPriority.Background); + }; + } + + #region INotifyPropertyChanged + public event PropertyChangedEventHandler PropertyChanged; + public void NotifyPropertyChanged([CallerMemberName] string property = "") + { + PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(property)); + } + + #endregion + } + +} \ No newline at end of file From ddfb9c892b258ca41c43462be635685a8bbe7dec Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Mon, 1 Jan 2024 19:43:29 +1300 Subject: [PATCH 8/8] Remove FFMPEG binaries --- OnnxStack.UI/OnnxStack.UI.csproj | 9 --------- 1 file changed, 9 deletions(-) diff --git a/OnnxStack.UI/OnnxStack.UI.csproj b/OnnxStack.UI/OnnxStack.UI.csproj index 86fe4918..030dd0f7 100644 --- a/OnnxStack.UI/OnnxStack.UI.csproj +++ b/OnnxStack.UI/OnnxStack.UI.csproj @@ -67,13 +67,4 @@ - - - PreserveNewest - - - PreserveNewest - - -