Skip to content
This repository was archived by the owner on Aug 22, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions Extension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,23 @@ public static Dictionary<string, long> GetSizeForEachDynamicLayerInBytes(this nn
}
}

public static void ToQuantizedModule<T>(
this T model)
where T : nn.Module
{
foreach (var (_, value) in model.named_children())
{
if (value is IQuantizeModule quantizeModule)
{
quantizeModule.Quantize();
}
else
{
value.ToQuantizedModule();
}
}
}

public static T ToDynamicLoadingModel<T>(
this T model,
Dictionary<string, string> deviceMap,
Expand Down
5 changes: 5 additions & 0 deletions Module/IDynamicLoadModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@ public interface IDynamicLoadModule

public Action<nn.Module>? UnloadFromDeviceFunc { get; set; }
}

public interface IQuantizeModule
{
public void Quantize();
}
9 changes: 5 additions & 4 deletions Module/Phi3Attention.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using FluentAssertions;
using Phi.Module;
using System;
using System.Collections.Generic;
using System.Linq;
Expand Down Expand Up @@ -71,8 +72,8 @@ public class Phi3Attention : nn.Module<Phi3AttentionInput, Phi3AttentionOutput>
private readonly Dictionary<string, object>? rope_scaling;
private readonly bool is_causal;

private readonly Linear o_proj;
private readonly Linear qkv_proj;
private readonly PhiLinear o_proj;
private readonly PhiLinear qkv_proj;
private nn.Module<Phi3RotaryEmbeddingInput, Phi3RotaryEmbeddingOutput> rotary_emb;

public Phi3Attention(Phi3Config config, int layer_idx)
Expand All @@ -95,8 +96,8 @@ public Phi3Attention(Phi3Config config, int layer_idx)
(this.head_dim * this.num_heads).Should().Be(this.hidden_size, "hidden_size must be divisible by num_heads");

var op_size = this.num_heads * this.head_dim + 2 * (this.num_key_value_heads * this.head_dim);
this.o_proj = nn.Linear(this.num_heads * this.head_dim, this.hidden_size, hasBias: false, dtype: config.DType);
this.qkv_proj = nn.Linear(this.hidden_size, op_size, hasBias: false, dtype: config.DType);
this.o_proj = new PhiInt4Linear(this.num_heads * this.head_dim, this.hidden_size, hasBias: false, dtype: config.DType);
this.qkv_proj = new PhiInt4Linear(this.hidden_size, op_size, hasBias: false, dtype: config.DType);
this._init_rope();
}

Expand Down
10 changes: 6 additions & 4 deletions Module/Phi3MLP.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
using static TorchSharp.torch;
using TorchSharp.Modules;
using TorchSharp;
using Phi.Module;
using Phi.Tests;

public class Phi3MLP : torch.nn.Module<Tensor, Tensor>
{
private readonly Linear gate_up_proj;
private readonly Linear down_proj;
private readonly PhiLinear gate_up_proj;
private readonly PhiLinear down_proj;
private readonly torch.nn.Module<Tensor, Tensor> activation_fn;

public Phi3MLP(Phi3Config config)
Expand All @@ -21,8 +23,8 @@ public Phi3MLP(Phi3Config config)
public Phi3MLP(int hiddenSize, int intermediateSize, string hiddenAct, ScalarType dtype)
: base(nameof(Phi3MLP))
{
this.gate_up_proj = torch.nn.Linear(hiddenSize, 2 * intermediateSize, hasBias: false, dtype: dtype);
this.down_proj = torch.nn.Linear(intermediateSize, hiddenSize, hasBias: false, dtype: dtype);
this.gate_up_proj = new PhiInt4Linear(hiddenSize, 2 * intermediateSize, hasBias: false, dtype: dtype);
this.down_proj = new PhiInt4Linear(intermediateSize, hiddenSize, hasBias: false, dtype: dtype);
this.RegisterComponents();
this.activation_fn = Utils.GetActivation(hiddenAct);
}
Expand Down
106 changes: 106 additions & 0 deletions Module/PhiInt4Linear.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using TorchSharp;
using static TorchSharp.torch;

namespace Phi.Module;

public class PhiInt4Linear : PhiLinear, IQuantizeModule
{
public PhiInt4Linear(int inFeatures, int outFeatures, bool hasBias = true, ScalarType dtype = ScalarType.Float32, string? device = null)
: base(inFeatures, outFeatures, hasBias, dtype, device)
{
}

public void Quantize()
{
using var _ = NewDisposeScope();
var timer = new System.Diagnostics.Stopwatch();
Console.WriteLine("Quantize start");
timer.Start();
// scale and zero point on vector-wise
// scale = 15 / max(weight, axis=1) - min(weight, axis=1)
var scale = 15 / (torch.max(this.weight, 1).values - torch.min(this.weight, 1).values);

// zero point = - scale * min(weight, axis=1) - 8
var zeroPoint = - scale * torch.min(this.weight, 1).values - 8;
// round zero point to nearest integer
zeroPoint = torch.round(zeroPoint);
var _4bitWeight = torch.round(this.weight * scale.view(-1, 1) + zeroPoint.view(-1, 1)).to(torch.int8);

zeroPoint = (zeroPoint + 8).to(torch.uint8);
_4bitWeight = (_4bitWeight + 8).view(-1).to(torch.uint8);

// torch doesn't provide int4, so we use int8 as placeholder
// and foreach int8, we save two int4, e.g. 0b1010 -> 0b10, 0b10
var placeHolderDim = this.outFeatures / 2 + this.outFeatures % 2;
var zpPlaceHolder = zeroPoint[..placeHolderDim];
zpPlaceHolder = zpPlaceHolder * 16 + zeroPoint[placeHolderDim..];

// assert zero point is in range [-128, 127]
//if (torch.any(this.zeroPoint < -128).item<bool>() || torch.any(this.zeroPoint > 127).item<bool>())
//{
// throw new Exception("Zero point is out of range [-128, 127]");
//}

// quantize weight
var _4bitWeightPlaceHolderDim =Convert.ToInt32(_4bitWeight.size(0) / 2 + _4bitWeight.size(0) % 2);
var _4bitWeightPlaceHolder = _4bitWeight[.._4bitWeightPlaceHolderDim];
_4bitWeightPlaceHolder = _4bitWeightPlaceHolder * 16 + _4bitWeight[_4bitWeightPlaceHolderDim..];

// assert weight is in range [-128, 127]
//if (torch.any(this._8bitWeight < -128).item<bool>() || torch.any(this._8bitWeight > 127).item<bool>())
//{
// throw new Exception("Weight is out of range [-128, 127]");
//}

// dispose float32 weight
this.weight.Dispose();
this.weight = null;

this._internal_buffers.Remove("weight");
this.register_buffer("4bit_weight", _4bitWeightPlaceHolder.MoveToOuterDisposeScope());
this.register_buffer("zeroPoint", zpPlaceHolder.MoveToOuterDisposeScope());
this.register_buffer("scale", scale.MoveToOuterDisposeScope());
timer.Stop();
Console.WriteLine($"Quantize end, elapsed time: {timer.ElapsedMilliseconds} ms");
}

public override Tensor forward(Tensor input)
{
if (this._internal_buffers.ContainsKey("weight"))
{
return base.forward(input);
}
else
{
using var dispose = torch.NewDisposeScope();
var weight = this.get_buffer("4bit_weight");
var weightLower = weight % 16;
var weightUpper = weight / 16;
weight = torch.cat([weightUpper, weightLower], 0).to(ScalarType.Float32);
weight = weight.view(this.outFeatures, this.inFeatures);
weight -= 8;
var zeroPoint = this.get_buffer("zeroPoint");
var zeroPointLower = zeroPoint % 16;
var zeroPointUpper = zeroPoint / 16;
zeroPoint = torch.cat([zeroPointUpper, zeroPointLower], 0).to(ScalarType.Float32);
zeroPoint -= 8;
var scale = this.get_buffer("scale").to(ScalarType.Float32);
var restoreWeight = (weight - zeroPoint.view(-1, 1)) / scale.view(-1, 1);
// use float32
var result = torch.matmul(input.to(ScalarType.Float32), restoreWeight.T);

if (this.bias is not null)
{
result = result + this.bias.to_type(ScalarType.Float32);
}

//result.Peek("result");
return result.to_type(input.dtype).MoveToOuterDisposeScope();
}
}
}
88 changes: 88 additions & 0 deletions Module/PhiInt8Linear.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using TorchSharp;
using static TorchSharp.torch;

namespace Phi.Module;

public class PhiInt8Linear : PhiLinear, IQuantizeModule
{
//private Tensor? scale;
//private Tensor? zeroPoint;
//private Tensor? _8bitWeight;

public PhiInt8Linear(int inFeatures, int outFeatures, bool hasBias = true, ScalarType dtype = ScalarType.Float32, string? device = null)
: base(inFeatures, outFeatures, hasBias, dtype, device)
{
}

public void Quantize()
{
var timer = new System.Diagnostics.Stopwatch();
Console.WriteLine("Quantize start");
timer.Start();
// scale and zero point on vector-wise
// scale = 255 / max(weight, axis=1) - min(weight, axis=1)
var scale = 255 / (torch.max(this.weight, 1).values - torch.min(this.weight, 1).values);

// zero point = - scale * min(weight, axis=1) - 128
var zeroPoint = - scale * torch.min(this.weight, 1).values - 128;
// round zero point to nearest integer
zeroPoint = torch.round(zeroPoint).to(torch.int8);

// assert zero point is in range [-128, 127]
//if (torch.any(this.zeroPoint < -128).item<bool>() || torch.any(this.zeroPoint > 127).item<bool>())
//{
// throw new Exception("Zero point is out of range [-128, 127]");
//}

// quantize weight
var _8bitWeight = torch.round(this.weight * scale.view(-1, 1)+ zeroPoint.view(-1, 1)).to(torch.int8);

// assert weight is in range [-128, 127]
//if (torch.any(this._8bitWeight < -128).item<bool>() || torch.any(this._8bitWeight > 127).item<bool>())
//{
// throw new Exception("Weight is out of range [-128, 127]");
//}

// dispose float32 weight
this.weight.Dispose();
this.weight = null;

this._internal_buffers.Remove("weight");
this.register_buffer("8bit_weight", _8bitWeight);
this.register_buffer("zeroPoint", zeroPoint);
this.register_buffer("scale", scale);
timer.Stop();
Console.WriteLine($"Quantize end, elapsed time: {timer.ElapsedMilliseconds} ms");
}

public override Tensor forward(Tensor input)
{
if (this._internal_buffers.ContainsKey("weight"))
{
return base.forward(input);
}
else
{
using var dispose = torch.NewDisposeScope();
var weight = this.get_buffer("8bit_weight").to(ScalarType.Float32);
var zeroPoint = this.get_buffer("zeroPoint").to(ScalarType.Float32);
var scale = this.get_buffer("scale").to(ScalarType.Float32);
var restoreWeight = (weight - zeroPoint.view(-1, 1)) / scale.view(-1, 1);
// use float32
var result = torch.matmul(input.to(ScalarType.Float32), restoreWeight.T);

if (this.bias is not null)
{
result = result + this.bias.to_type(ScalarType.Float32);
}

//result.Peek("result");
return result.to_type(input.dtype).MoveToOuterDisposeScope();
}
}
}
9 changes: 5 additions & 4 deletions Module/PhiLinear.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@

public class PhiLinear : nn.Module<Tensor, Tensor>
{
private readonly Tensor weight;
private readonly Tensor? bias;
private int inFeatures;
private int outFeatures;
protected Tensor? weight;
protected readonly Tensor? bias;
protected int inFeatures;
protected int outFeatures;

public PhiLinear(int inFeatures, int outFeatures, bool hasBias = true, ScalarType dtype = ScalarType.Float32, string? device = null)
: base(nameof(PhiLinear))
{
this.inFeatures = inFeatures;
this.outFeatures = outFeatures;
device = device ?? "cpu";
this.weight = torch.randn(outFeatures, inFeatures, dtype: dtype, device: device);

if (hasBias)
Expand Down
2 changes: 1 addition & 1 deletion Phi3/Phi3ForCasualLM.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public static Phi3ForCasualLM FromPretrained(
modelConfig.DType = torchDtype;
var phi = new Phi3ForCasualLM(modelConfig);
var loadedParameters = new Dictionary<string, bool>();
phi.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, loadedParameters: loadedParameters);
phi.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, loadedParameters: loadedParameters, useTqdm: false);
phi = phi.to(device);
phi.eval();

Expand Down
14 changes: 7 additions & 7 deletions Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
using TorchSharp;
using static TorchSharp.torch;

var phi2Folder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
var phiFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
var device = "cpu";

if (device == "cuda")
Expand All @@ -22,22 +22,22 @@

Console.WriteLine("Loading Phi3 from huggingface model weight folder");
var timer = System.Diagnostics.Stopwatch.StartNew();
var model = Phi3ForCasualLM.FromPretrained(phi2Folder, device: device, torchDtype: defaultType, checkPointName: "model.safetensors.index.json");
var tokenizer = LLama2Tokenizer.FromPretrained(phi2Folder);
var model = Phi3ForCasualLM.FromPretrained(phiFolder, device: device, torchDtype: defaultType, checkPointName: "model.safetensors.index.json");
var tokenizer = LLama2Tokenizer.FromPretrained(phiFolder);

var deviceSizeMap = new Dictionary<string, long>
{
["cuda:0"] = 0L * 1024 * 1024 * 1024,
["cuda:0"] = 5L * 1024 * 1024 * 1024,
["cpu"] = 64L * 1024 * 1024 * 1024,
["disk"] = 2L * 1024 * 1024 * 1024 * 1024,
};
model.ToQuantizedModule();
var deviceMap = model.InferDeviceMapForEachLayer(
devices: [ "cuda:0", "cpu", "disk" ],
deviceSizeMapInByte: deviceSizeMap);

var json = JsonSerializer.Serialize(deviceMap, new JsonSerializerOptions { WriteIndented = true });
Console.WriteLine(json);

model = model.ToDynamicLoadingModel(deviceMap, "cuda:0");

var pipeline = new CasualLMPipeline(tokenizer, model, device);
Expand All @@ -49,8 +49,8 @@
// agent
var agent = new Phi3Agent(pipeline, "assistant")
.RegisterPrintMessage();
var question = @"count to 3";
var systemMessage = new TextMessage(Role.System, "You are a helpful AI assistant that always respond in JSON format");
var question = @"Use C# to calculate 100th fibonacci";
var systemMessage = new TextMessage(Role.System, "You are a helpful AI assistant.");
var userMessage = new TextMessage(Role.User, question);
for (int i = 0; i!= 100; ++i)
{
Expand Down
Loading