diff --git a/Extension.cs b/Extension.cs index 954e99f..773464c 100644 --- a/Extension.cs +++ b/Extension.cs @@ -49,6 +49,23 @@ public static Dictionary GetSizeForEachDynamicLayerInBytes(this nn } } + public static void ToQuantizedModule( + this T model) + where T : nn.Module + { + foreach (var (_, value) in model.named_children()) + { + if (value is IQuantizeModule quantizeModule) + { + quantizeModule.Quantize(); + } + else + { + value.ToQuantizedModule(); + } + } + } + public static T ToDynamicLoadingModel( this T model, Dictionary deviceMap, diff --git a/Module/IDynamicLoadModule.cs b/Module/IDynamicLoadModule.cs index 25ba6cb..f5bc8b7 100644 --- a/Module/IDynamicLoadModule.cs +++ b/Module/IDynamicLoadModule.cs @@ -13,3 +13,8 @@ public interface IDynamicLoadModule public Action? UnloadFromDeviceFunc { get; set; } } + +public interface IQuantizeModule +{ + public void Quantize(); +} diff --git a/Module/Phi3Attention.cs b/Module/Phi3Attention.cs index 74826de..4c0c9ae 100644 --- a/Module/Phi3Attention.cs +++ b/Module/Phi3Attention.cs @@ -1,4 +1,5 @@ using FluentAssertions; +using Phi.Module; using System; using System.Collections.Generic; using System.Linq; @@ -71,8 +72,8 @@ public class Phi3Attention : nn.Module private readonly Dictionary? rope_scaling; private readonly bool is_causal; - private readonly Linear o_proj; - private readonly Linear qkv_proj; + private readonly PhiLinear o_proj; + private readonly PhiLinear qkv_proj; private nn.Module rotary_emb; public Phi3Attention(Phi3Config config, int layer_idx) @@ -95,8 +96,8 @@ public Phi3Attention(Phi3Config config, int layer_idx) (this.head_dim * this.num_heads).Should().Be(this.hidden_size, "hidden_size must be divisible by num_heads"); var op_size = this.num_heads * this.head_dim + 2 * (this.num_key_value_heads * this.head_dim); - this.o_proj = nn.Linear(this.num_heads * this.head_dim, this.hidden_size, hasBias: false, dtype: config.DType); - this.qkv_proj = nn.Linear(this.hidden_size, op_size, hasBias: false, dtype: config.DType); + this.o_proj = new PhiInt4Linear(this.num_heads * this.head_dim, this.hidden_size, hasBias: false, dtype: config.DType); + this.qkv_proj = new PhiInt4Linear(this.hidden_size, op_size, hasBias: false, dtype: config.DType); this._init_rope(); } diff --git a/Module/Phi3MLP.cs b/Module/Phi3MLP.cs index f9e5c66..4a1b074 100644 --- a/Module/Phi3MLP.cs +++ b/Module/Phi3MLP.cs @@ -6,11 +6,13 @@ using static TorchSharp.torch; using TorchSharp.Modules; using TorchSharp; +using Phi.Module; +using Phi.Tests; public class Phi3MLP : torch.nn.Module { - private readonly Linear gate_up_proj; - private readonly Linear down_proj; + private readonly PhiLinear gate_up_proj; + private readonly PhiLinear down_proj; private readonly torch.nn.Module activation_fn; public Phi3MLP(Phi3Config config) @@ -21,8 +23,8 @@ public Phi3MLP(Phi3Config config) public Phi3MLP(int hiddenSize, int intermediateSize, string hiddenAct, ScalarType dtype) : base(nameof(Phi3MLP)) { - this.gate_up_proj = torch.nn.Linear(hiddenSize, 2 * intermediateSize, hasBias: false, dtype: dtype); - this.down_proj = torch.nn.Linear(intermediateSize, hiddenSize, hasBias: false, dtype: dtype); + this.gate_up_proj = new PhiInt4Linear(hiddenSize, 2 * intermediateSize, hasBias: false, dtype: dtype); + this.down_proj = new PhiInt4Linear(intermediateSize, hiddenSize, hasBias: false, dtype: dtype); this.RegisterComponents(); this.activation_fn = Utils.GetActivation(hiddenAct); } diff --git a/Module/PhiInt4Linear.cs b/Module/PhiInt4Linear.cs new file mode 100644 index 0000000..594a7ef --- /dev/null +++ b/Module/PhiInt4Linear.cs @@ -0,0 +1,106 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using TorchSharp; +using static TorchSharp.torch; + +namespace Phi.Module; + +public class PhiInt4Linear : PhiLinear, IQuantizeModule +{ + public PhiInt4Linear(int inFeatures, int outFeatures, bool hasBias = true, ScalarType dtype = ScalarType.Float32, string? device = null) + : base(inFeatures, outFeatures, hasBias, dtype, device) + { + } + + public void Quantize() + { + using var _ = NewDisposeScope(); + var timer = new System.Diagnostics.Stopwatch(); + Console.WriteLine("Quantize start"); + timer.Start(); + // scale and zero point on vector-wise + // scale = 15 / max(weight, axis=1) - min(weight, axis=1) + var scale = 15 / (torch.max(this.weight, 1).values - torch.min(this.weight, 1).values); + + // zero point = - scale * min(weight, axis=1) - 8 + var zeroPoint = - scale * torch.min(this.weight, 1).values - 8; + // round zero point to nearest integer + zeroPoint = torch.round(zeroPoint); + var _4bitWeight = torch.round(this.weight * scale.view(-1, 1) + zeroPoint.view(-1, 1)).to(torch.int8); + + zeroPoint = (zeroPoint + 8).to(torch.uint8); + _4bitWeight = (_4bitWeight + 8).view(-1).to(torch.uint8); + + // torch doesn't provide int4, so we use int8 as placeholder + // and foreach int8, we save two int4, e.g. 0b1010 -> 0b10, 0b10 + var placeHolderDim = this.outFeatures / 2 + this.outFeatures % 2; + var zpPlaceHolder = zeroPoint[..placeHolderDim]; + zpPlaceHolder = zpPlaceHolder * 16 + zeroPoint[placeHolderDim..]; + + // assert zero point is in range [-128, 127] + //if (torch.any(this.zeroPoint < -128).item() || torch.any(this.zeroPoint > 127).item()) + //{ + // throw new Exception("Zero point is out of range [-128, 127]"); + //} + + // quantize weight + var _4bitWeightPlaceHolderDim =Convert.ToInt32(_4bitWeight.size(0) / 2 + _4bitWeight.size(0) % 2); + var _4bitWeightPlaceHolder = _4bitWeight[.._4bitWeightPlaceHolderDim]; + _4bitWeightPlaceHolder = _4bitWeightPlaceHolder * 16 + _4bitWeight[_4bitWeightPlaceHolderDim..]; + + // assert weight is in range [-128, 127] + //if (torch.any(this._8bitWeight < -128).item() || torch.any(this._8bitWeight > 127).item()) + //{ + // throw new Exception("Weight is out of range [-128, 127]"); + //} + + // dispose float32 weight + this.weight.Dispose(); + this.weight = null; + + this._internal_buffers.Remove("weight"); + this.register_buffer("4bit_weight", _4bitWeightPlaceHolder.MoveToOuterDisposeScope()); + this.register_buffer("zeroPoint", zpPlaceHolder.MoveToOuterDisposeScope()); + this.register_buffer("scale", scale.MoveToOuterDisposeScope()); + timer.Stop(); + Console.WriteLine($"Quantize end, elapsed time: {timer.ElapsedMilliseconds} ms"); + } + + public override Tensor forward(Tensor input) + { + if (this._internal_buffers.ContainsKey("weight")) + { + return base.forward(input); + } + else + { + using var dispose = torch.NewDisposeScope(); + var weight = this.get_buffer("4bit_weight"); + var weightLower = weight % 16; + var weightUpper = weight / 16; + weight = torch.cat([weightUpper, weightLower], 0).to(ScalarType.Float32); + weight = weight.view(this.outFeatures, this.inFeatures); + weight -= 8; + var zeroPoint = this.get_buffer("zeroPoint"); + var zeroPointLower = zeroPoint % 16; + var zeroPointUpper = zeroPoint / 16; + zeroPoint = torch.cat([zeroPointUpper, zeroPointLower], 0).to(ScalarType.Float32); + zeroPoint -= 8; + var scale = this.get_buffer("scale").to(ScalarType.Float32); + var restoreWeight = (weight - zeroPoint.view(-1, 1)) / scale.view(-1, 1); + // use float32 + var result = torch.matmul(input.to(ScalarType.Float32), restoreWeight.T); + + if (this.bias is not null) + { + result = result + this.bias.to_type(ScalarType.Float32); + } + + //result.Peek("result"); + return result.to_type(input.dtype).MoveToOuterDisposeScope(); + } + } +} diff --git a/Module/PhiInt8Linear.cs b/Module/PhiInt8Linear.cs new file mode 100644 index 0000000..ac8eb02 --- /dev/null +++ b/Module/PhiInt8Linear.cs @@ -0,0 +1,88 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using TorchSharp; +using static TorchSharp.torch; + +namespace Phi.Module; + +public class PhiInt8Linear : PhiLinear, IQuantizeModule +{ + //private Tensor? scale; + //private Tensor? zeroPoint; + //private Tensor? _8bitWeight; + + public PhiInt8Linear(int inFeatures, int outFeatures, bool hasBias = true, ScalarType dtype = ScalarType.Float32, string? device = null) + : base(inFeatures, outFeatures, hasBias, dtype, device) + { + } + + public void Quantize() + { + var timer = new System.Diagnostics.Stopwatch(); + Console.WriteLine("Quantize start"); + timer.Start(); + // scale and zero point on vector-wise + // scale = 255 / max(weight, axis=1) - min(weight, axis=1) + var scale = 255 / (torch.max(this.weight, 1).values - torch.min(this.weight, 1).values); + + // zero point = - scale * min(weight, axis=1) - 128 + var zeroPoint = - scale * torch.min(this.weight, 1).values - 128; + // round zero point to nearest integer + zeroPoint = torch.round(zeroPoint).to(torch.int8); + + // assert zero point is in range [-128, 127] + //if (torch.any(this.zeroPoint < -128).item() || torch.any(this.zeroPoint > 127).item()) + //{ + // throw new Exception("Zero point is out of range [-128, 127]"); + //} + + // quantize weight + var _8bitWeight = torch.round(this.weight * scale.view(-1, 1)+ zeroPoint.view(-1, 1)).to(torch.int8); + + // assert weight is in range [-128, 127] + //if (torch.any(this._8bitWeight < -128).item() || torch.any(this._8bitWeight > 127).item()) + //{ + // throw new Exception("Weight is out of range [-128, 127]"); + //} + + // dispose float32 weight + this.weight.Dispose(); + this.weight = null; + + this._internal_buffers.Remove("weight"); + this.register_buffer("8bit_weight", _8bitWeight); + this.register_buffer("zeroPoint", zeroPoint); + this.register_buffer("scale", scale); + timer.Stop(); + Console.WriteLine($"Quantize end, elapsed time: {timer.ElapsedMilliseconds} ms"); + } + + public override Tensor forward(Tensor input) + { + if (this._internal_buffers.ContainsKey("weight")) + { + return base.forward(input); + } + else + { + using var dispose = torch.NewDisposeScope(); + var weight = this.get_buffer("8bit_weight").to(ScalarType.Float32); + var zeroPoint = this.get_buffer("zeroPoint").to(ScalarType.Float32); + var scale = this.get_buffer("scale").to(ScalarType.Float32); + var restoreWeight = (weight - zeroPoint.view(-1, 1)) / scale.view(-1, 1); + // use float32 + var result = torch.matmul(input.to(ScalarType.Float32), restoreWeight.T); + + if (this.bias is not null) + { + result = result + this.bias.to_type(ScalarType.Float32); + } + + //result.Peek("result"); + return result.to_type(input.dtype).MoveToOuterDisposeScope(); + } + } +} diff --git a/Module/PhiLinear.cs b/Module/PhiLinear.cs index fe0723e..e75c4dd 100644 --- a/Module/PhiLinear.cs +++ b/Module/PhiLinear.cs @@ -3,16 +3,17 @@ public class PhiLinear : nn.Module { - private readonly Tensor weight; - private readonly Tensor? bias; - private int inFeatures; - private int outFeatures; + protected Tensor? weight; + protected readonly Tensor? bias; + protected int inFeatures; + protected int outFeatures; public PhiLinear(int inFeatures, int outFeatures, bool hasBias = true, ScalarType dtype = ScalarType.Float32, string? device = null) : base(nameof(PhiLinear)) { this.inFeatures = inFeatures; this.outFeatures = outFeatures; + device = device ?? "cpu"; this.weight = torch.randn(outFeatures, inFeatures, dtype: dtype, device: device); if (hasBias) diff --git a/Phi3/Phi3ForCasualLM.cs b/Phi3/Phi3ForCasualLM.cs index 325f0a1..bbd779c 100644 --- a/Phi3/Phi3ForCasualLM.cs +++ b/Phi3/Phi3ForCasualLM.cs @@ -49,7 +49,7 @@ public static Phi3ForCasualLM FromPretrained( modelConfig.DType = torchDtype; var phi = new Phi3ForCasualLM(modelConfig); var loadedParameters = new Dictionary(); - phi.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, loadedParameters: loadedParameters); + phi.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, loadedParameters: loadedParameters, useTqdm: false); phi = phi.to(device); phi.eval(); diff --git a/Program.cs b/Program.cs index 0723c2a..f01b5a6 100644 --- a/Program.cs +++ b/Program.cs @@ -8,7 +8,7 @@ using TorchSharp; using static TorchSharp.torch; -var phi2Folder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"; +var phiFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"; var device = "cpu"; if (device == "cuda") @@ -22,22 +22,22 @@ Console.WriteLine("Loading Phi3 from huggingface model weight folder"); var timer = System.Diagnostics.Stopwatch.StartNew(); -var model = Phi3ForCasualLM.FromPretrained(phi2Folder, device: device, torchDtype: defaultType, checkPointName: "model.safetensors.index.json"); -var tokenizer = LLama2Tokenizer.FromPretrained(phi2Folder); +var model = Phi3ForCasualLM.FromPretrained(phiFolder, device: device, torchDtype: defaultType, checkPointName: "model.safetensors.index.json"); +var tokenizer = LLama2Tokenizer.FromPretrained(phiFolder); var deviceSizeMap = new Dictionary { - ["cuda:0"] = 0L * 1024 * 1024 * 1024, + ["cuda:0"] = 5L * 1024 * 1024 * 1024, ["cpu"] = 64L * 1024 * 1024 * 1024, ["disk"] = 2L * 1024 * 1024 * 1024 * 1024, }; +model.ToQuantizedModule(); var deviceMap = model.InferDeviceMapForEachLayer( devices: [ "cuda:0", "cpu", "disk" ], deviceSizeMapInByte: deviceSizeMap); var json = JsonSerializer.Serialize(deviceMap, new JsonSerializerOptions { WriteIndented = true }); Console.WriteLine(json); - model = model.ToDynamicLoadingModel(deviceMap, "cuda:0"); var pipeline = new CasualLMPipeline(tokenizer, model, device); @@ -49,8 +49,8 @@ // agent var agent = new Phi3Agent(pipeline, "assistant") .RegisterPrintMessage(); -var question = @"count to 3"; -var systemMessage = new TextMessage(Role.System, "You are a helpful AI assistant that always respond in JSON format"); +var question = @"Use C# to calculate 100th fibonacci"; +var systemMessage = new TextMessage(Role.System, "You are a helpful AI assistant."); var userMessage = new TextMessage(Role.User, question); for (int i = 0; i!= 100; ++i) { diff --git a/Tests/Approvals/DynamicLoadingTest.ItInferDeviceMapTestsAsync.received.txt b/Tests/Approvals/DynamicLoadingTest.ItInferDeviceMapTestsAsync.received.txt index e350901..c2abecc 100644 --- a/Tests/Approvals/DynamicLoadingTest.ItInferDeviceMapTestsAsync.received.txt +++ b/Tests/Approvals/DynamicLoadingTest.ItInferDeviceMapTestsAsync.received.txt @@ -2,66 +2,66 @@ "linear1": "cuda:0", "linears.0": "cuda:0", "linears.1": "cuda:0", - "linears.2": "cuda:0", - "linears.3": "cuda:0", - "linears.4": "cuda:0", - "linears.5": "cuda:0", - "linears.6": "cuda:0", - "linears.7": "cuda:0", - "linears.8": "cuda:0", - "linears.9": "cuda:0", - "linears.10": "cuda:0", - "linears.11": "cuda:0", - "linears.12": "cuda:0", - "linears.13": "cuda:0", - "linears.14": "cuda:0", - "linears.15": "cuda:0", - "linears.16": "cuda:0", - "linears.17": "cuda:0", - "linears.18": "cuda:0", - "linears.19": "cuda:0", - "linears.20": "cuda:0", - "linears.21": "cuda:0", - "linears.22": "cuda:0", - "linears.23": "cuda:0", - "linears.24": "cuda:0", - "linears.25": "cuda:0", - "linears.26": "cuda:0", - "linears.27": "cuda:0", - "linears.28": "cuda:0", - "linears.29": "cuda:0", - "linears.30": "cuda:0", - "linears.31": "cuda:0", - "linears.32": "cuda:0", - "linears.33": "cuda:0", - "linears.34": "cuda:0", - "linears.35": "cuda:0", - "linears.36": "cuda:0", - "linears.37": "cuda:0", - "linears.38": "cuda:0", - "linears.39": "cuda:0", - "linears.40": "cuda:0", - "linears.41": "cuda:0", - "linears.42": "cuda:0", - "linears.43": "cuda:0", - "linears.44": "cuda:0", - "linears.45": "cuda:0", - "linears.46": "cuda:0", - "linears.47": "cuda:0", - "linears.48": "cuda:0", - "linears.49": "cuda:0", - "linears.50": "cuda:0", - "linears.51": "cuda:0", - "linears.52": "cuda:0", - "linears.53": "cuda:0", - "linears.54": "cuda:0", - "linears.55": "cuda:0", - "linears.56": "cuda:0", - "linears.57": "cuda:0", - "linears.58": "cuda:0", - "linears.59": "cuda:0", - "linears.60": "cuda:0", - "linears.61": "cuda:0", - "linears.62": "cuda:0", - "linears.63": "cuda:0" + "linears.2": "cpu", + "linears.3": "cpu", + "linears.4": "disk", + "linears.5": "disk", + "linears.6": "disk", + "linears.7": "disk", + "linears.8": "disk", + "linears.9": "disk", + "linears.10": "disk", + "linears.11": "disk", + "linears.12": "disk", + "linears.13": "disk", + "linears.14": "disk", + "linears.15": "disk", + "linears.16": "disk", + "linears.17": "disk", + "linears.18": "disk", + "linears.19": "disk", + "linears.20": "disk", + "linears.21": "disk", + "linears.22": "disk", + "linears.23": "disk", + "linears.24": "disk", + "linears.25": "disk", + "linears.26": "disk", + "linears.27": "disk", + "linears.28": "disk", + "linears.29": "disk", + "linears.30": "disk", + "linears.31": "disk", + "linears.32": "disk", + "linears.33": "disk", + "linears.34": "disk", + "linears.35": "disk", + "linears.36": "disk", + "linears.37": "disk", + "linears.38": "disk", + "linears.39": "disk", + "linears.40": "disk", + "linears.41": "disk", + "linears.42": "disk", + "linears.43": "disk", + "linears.44": "disk", + "linears.45": "disk", + "linears.46": "disk", + "linears.47": "disk", + "linears.48": "disk", + "linears.49": "disk", + "linears.50": "disk", + "linears.51": "disk", + "linears.52": "disk", + "linears.53": "disk", + "linears.54": "disk", + "linears.55": "disk", + "linears.56": "disk", + "linears.57": "disk", + "linears.58": "disk", + "linears.59": "disk", + "linears.60": "disk", + "linears.61": "disk", + "linears.62": "disk", + "linears.63": "disk" } \ No newline at end of file diff --git a/Tests/DynamicLoadingTest.cs b/Tests/DynamicLoadingTest.cs index 7817761..f182537 100644 --- a/Tests/DynamicLoadingTest.cs +++ b/Tests/DynamicLoadingTest.cs @@ -20,7 +20,7 @@ namespace Phi.Tests; public class DynamicLoadingTest { private ITestOutputHelper output; - private int testDimension = 1024 * 2 * 2; + private int testDimension = 1024 * 4; private int preHeat = 1; private int loops = 10; private int numLayers = 512; @@ -84,7 +84,15 @@ public async Task CPUBenchmarkAsync() var device = "cpu"; var input = torch.randn(testDimension, testDimension, device: device); var model = new SequentialLinear(testDimension, device); + await BenchmarkAsync(device, input, model); + } + [Fact] + public async Task QuantizeGPUBenchmarkAsync() + { + var device = "cuda:0"; + var input = torch.randn(testDimension, testDimension, device: device); + var model = new SequentialLinear(testDimension, device); await BenchmarkAsync(device, input, model); } @@ -161,7 +169,7 @@ private class SequentialLinear : nn.Module //private readonly DynamicLoadingModule linear4; //private readonly DynamicLoadingModule linear5; - private readonly ModuleList> linears; + private readonly ModuleList> linears; public SequentialLinear(int features, string? device = null, int numberOfLayers = 64) : base(nameof(SequentialLinear)) { @@ -171,11 +179,11 @@ public SequentialLinear(int features, string? device = null, int numberOfLayers //this.linear4 = DynamicLoadingModule.CreateFromModel(new PhiLinear(features, features, device: device)); //this.linear5 = DynamicLoadingModule.CreateFromModel(new PhiLinear(features, features, device: device)); - this.linears = new ModuleList>(); + this.linears = new ModuleList>(); for (int i = 0; i < numberOfLayers; i++) { - this.linears.Add(DynamicLoadingModule.CreateFromModel(new PhiLinear(features, features, device: device))); + this.linears.Add(DynamicLoadingModule.CreateFromModel(new PhiInt8Linear(features, features, device: device))); } this.RegisterComponents(); diff --git a/Tests/PhiLint4LinearTests.cs b/Tests/PhiLint4LinearTests.cs new file mode 100644 index 0000000..3b023cc --- /dev/null +++ b/Tests/PhiLint4LinearTests.cs @@ -0,0 +1,86 @@ +using FluentAssertions; +using Phi.Module; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using TorchSharp; +using Xunit; +using static TorchSharp.torch; + +namespace Phi.Tests; + +public class PhiLint4LinearTests +{ + [Fact] + public void SizeTests() + { + // meta is critical for the test + // as the size of the model to test is 372 GB + // and can't be loaded in real device like cpu or cuda + var device = "meta"; + var model = new PhiInt4Linear(100000, 100, device: device); + + var sizeInBytes = model.GetSizeInBytes(); + + var sizeInGigaBytes = sizeInBytes / 1024 / 1024; + sizeInGigaBytes.Should().Be(38); + + // to int8 + model.Quantize(); + var sizeInBytesAfterInt8 = model.GetSizeInBytes(); + + var sizeInGigaBytesAfterInt8 = sizeInBytesAfterInt8 / 1024 / 1024; + sizeInGigaBytesAfterInt8.Should().Be(4); + } + + [Fact] + public void ForwardTest() + { + var device = "cpu"; + var model = new PhiInt4Linear(123, 10, device: device); + + // set both weight and bias to rand int8 values + // and compare the result before and after ToInt8 + + var input = torch.randint(-8, 7, [10 ,2200, 123], device: device); + var weight = torch.randint(-8, 7, [10, 123], device: device); + var bias = torch.randint(-8, 7, [10], device: device); + + var weightStr = weight.Peek("weight").ToString(); + + weight = (weight + 8).view(-1).to(torch.uint8); + var weightPlaceHolderDim = (int)weight.size(0); + weightPlaceHolderDim = weightPlaceHolderDim / 2 + weightPlaceHolderDim % 2; + var weightPlaceHolder = weight[..weightPlaceHolderDim]; + weightPlaceHolder = weightPlaceHolder * 16 + weight[weightPlaceHolderDim..]; + + var high4Bit = weightPlaceHolder / 16; + var low4Bit = weightPlaceHolder % 16; + weight = torch.cat(new Tensor[] { high4Bit, low4Bit }).view(10, 123); + weight = weight.to(torch.int64); + weight -= 8; + weight.Peek("weight").Should().Be(weightStr); + + // scale and zero point on vector-wise + //input = input * (250 / (torch.max(input) - torch.min(input))); + //weight = weight * (250 / (torch.max(weight) - torch.min(weight))).view(-1, 1); + //bias = bias * (250 / (torch.max(bias) - torch.min(bias))); + + model.load_state_dict(new Dictionary + { + ["weight"] = weight, + ["bias"] = bias + }); + + var resultBeforeInt8 = model.forward(input); + + model.Quantize(); + + var resultAfterInt8 = model.forward(input); + + // compare the result + resultBeforeInt8.Peek("result").Should().Be(resultAfterInt8.Peek("result")); + } +} diff --git a/Tests/PhiLint8LinearTests.cs b/Tests/PhiLint8LinearTests.cs new file mode 100644 index 0000000..f4d2928 --- /dev/null +++ b/Tests/PhiLint8LinearTests.cs @@ -0,0 +1,128 @@ +using FluentAssertions; +using Phi.Module; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using TorchSharp; +using Xunit; +using Xunit.Abstractions; +using static TorchSharp.torch; + +namespace Phi.Tests; + +public class PhiLint8LinearTests +{ + private ITestOutputHelper output; + + public PhiLint8LinearTests(ITestOutputHelper output) + { + this.output = output; + } + + [Fact] + public void SizeTests() + { + // meta is critical for the test + // as the size of the model to test is 372 GB + // and can't be loaded in real device like cpu or cuda + var device = "meta"; + var model = new PhiInt8Linear(100000, 100, device: device); + + var sizeInBytes = model.GetSizeInBytes(); + + var sizeInGigaBytes = sizeInBytes / 1024 / 1024; + sizeInGigaBytes.Should().Be(38); + + // to int8 + model.Quantize(); + var sizeInBytesAfterInt8 = model.GetSizeInBytes(); + + var sizeInGigaBytesAfterInt8 = sizeInBytesAfterInt8 / 1024 / 1024; + sizeInGigaBytesAfterInt8.Should().Be(9); + } + + [Fact] + public void ForwardTest() + { + var device = "cpu"; + var model = new PhiInt8Linear(123, 10, device: device); + + // set both weight and bias to rand int8 values + // and compare the result before and after ToInt8 + + var input = torch.randint(-128, 127, [10, 2200, 123], device: device); + var weight = torch.randint(-128, 127, [10, 123], device: device); + var bias = torch.randint(-128, 127, [10], device: device); + + // scale and zero point on vector-wise + //input = input * (250 / (torch.max(input) - torch.min(input))); + //weight = weight * (250 / (torch.max(weight) - torch.min(weight))).view(-1, 1); + //bias = bias * (250 / (torch.max(bias) - torch.min(bias))); + + model.load_state_dict(new Dictionary + { + ["weight"] = weight, + ["bias"] = bias + }); + + var resultBeforeInt8 = model.forward(input); + + model.Quantize(); + + var resultAfterInt8 = model.forward(input); + + // compare the result + resultBeforeInt8.Peek("result").Should().Be(resultAfterInt8.Peek("result")); + } + + [Fact] + public void MatMulitBenchmark() + { + var sizeX = new long[] { 1, 1000, 1000 }; + var sizeY = new long[] { 1000, 100 }; + var device = "cpu"; + // float32 + var x = torch.randn(sizeX, device: device); + var y = torch.randn(sizeY, device: device); + + // warm up + for (var i = 0; i < 10; i++) + { + var _ = torch.matmul(x, y); + } + + // measure + var timer = System.Diagnostics.Stopwatch.StartNew(); + for (var i = 0; i < 10; i++) + { + var _ = torch.matmul(x, y); + } + timer.Stop(); + + output.WriteLine($"MatMulitBenchmark elapsed time: {timer.ElapsedMilliseconds} ms"); + + // int8 + var xInt8 = x.to(ScalarType.Int8); + var yInt8 = y.to(ScalarType.Int8); + + // warm up + for (var i = 0; i < 10; i++) + { + var _ = torch.matmul(xInt8, yInt8); + } + + // measure + timer.Restart(); + for (var i = 0; i < 10; i++) + { + var _ = torch.matmul(xInt8, yInt8); + } + + timer.Stop(); + + output.WriteLine($"MatMulitBenchmark int8 elapsed time: {timer.ElapsedMilliseconds} ms"); + + } +} diff --git a/Torchsharp-phi.csproj b/Torchsharp-phi.csproj index 8dcdd33..1e226b0 100644 --- a/Torchsharp-phi.csproj +++ b/Torchsharp-phi.csproj @@ -23,7 +23,7 @@ - +