diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln
index d3985d1777..c55f5797f2 100644
--- a/Microsoft.ML.sln
+++ b/Microsoft.ML.sln
@@ -184,7 +184,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Phi.Test
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Samples", "docs\samples\Microsoft.ML.GenAI.Samples\Microsoft.ML.GenAI.Samples.csproj", "{1D4AD9A3-19AF-432B-889D-A63FE6D7BD47}"
EndProject
-Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Core.Tests", "test\Microsoft.ML.GenAI.Core.Tests\Microsoft.ML.GenAI.Core.Tests.csproj", "{14AB0804-D4CE-4634-B544-5A8587620783}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.Core.Tests", "test\Microsoft.ML.GenAI.Core.Tests\Microsoft.ML.GenAI.Core.Tests.csproj", "{14AB0804-D4CE-4634-B544-5A8587620783}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.GenAI.LLaMA", "src\Microsoft.ML.GenAI.LLaMA\Microsoft.ML.GenAI.LLaMA.csproj", "{0AA6D5CB-195F-457A-8792-4221E76E6C44}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.LLaMA.Tests", "test\Microsoft.ML.GenAI.LLaMA.Tests\Microsoft.ML.GenAI.LLaMA.Tests.csproj", "{D202353D-6FAF-4263-9A01-BDCFBC92391F}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -878,6 +882,22 @@ Global
{14AB0804-D4CE-4634-B544-5A8587620783}.Release|Any CPU.Build.0 = Release|Any CPU
{14AB0804-D4CE-4634-B544-5A8587620783}.Release|x64.ActiveCfg = Release|Any CPU
{14AB0804-D4CE-4634-B544-5A8587620783}.Release|x64.Build.0 = Release|Any CPU
+ {0AA6D5CB-195F-457A-8792-4221E76E6C44}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {0AA6D5CB-195F-457A-8792-4221E76E6C44}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {0AA6D5CB-195F-457A-8792-4221E76E6C44}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {0AA6D5CB-195F-457A-8792-4221E76E6C44}.Debug|x64.Build.0 = Debug|Any CPU
+ {0AA6D5CB-195F-457A-8792-4221E76E6C44}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {0AA6D5CB-195F-457A-8792-4221E76E6C44}.Release|Any CPU.Build.0 = Release|Any CPU
+ {0AA6D5CB-195F-457A-8792-4221E76E6C44}.Release|x64.ActiveCfg = Release|Any CPU
+ {0AA6D5CB-195F-457A-8792-4221E76E6C44}.Release|x64.Build.0 = Release|Any CPU
+ {D202353D-6FAF-4263-9A01-BDCFBC92391F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {D202353D-6FAF-4263-9A01-BDCFBC92391F}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {D202353D-6FAF-4263-9A01-BDCFBC92391F}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {D202353D-6FAF-4263-9A01-BDCFBC92391F}.Debug|x64.Build.0 = Debug|Any CPU
+ {D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|Any CPU.Build.0 = Release|Any CPU
+ {D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|x64.ActiveCfg = Release|Any CPU
+ {D202353D-6FAF-4263-9A01-BDCFBC92391F}.Release|x64.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -969,6 +989,8 @@ Global
{867FFC34-DFA7-400F-B9BB-85158326CE08} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{1D4AD9A3-19AF-432B-889D-A63FE6D7BD47} = {DA452A53-2E94-4433-B08C-041EDEC729E6}
{14AB0804-D4CE-4634-B544-5A8587620783} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
+ {0AA6D5CB-195F-457A-8792-4221E76E6C44} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {D202353D-6FAF-4263-9A01-BDCFBC92391F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
diff --git a/NuGet.config b/NuGet.config
index 15f4fc551b..5f023aa721 100644
--- a/NuGet.config
+++ b/NuGet.config
@@ -13,6 +13,7 @@
+
@@ -40,6 +41,9 @@
+
+
+
diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Llama/LLaMA3_1.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/LLaMA3_1.cs
new file mode 100644
index 0000000000..49fcdf5892
--- /dev/null
+++ b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/LLaMA3_1.cs
@@ -0,0 +1,51 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Text.Json;
+using System.Threading.Tasks;
+using AutoGen.Core;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Core.Extension;
+using Microsoft.ML.GenAI.LLaMA;
+using Microsoft.ML.Tokenizers;
+using TorchSharp;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.Samples.Llama;
+
+internal class LlamaSample
+{
+ public static async Task Run()
+ {
+ var device = "cuda";
+ if (device == "cuda")
+ {
+ torch.InitializeDeviceType(DeviceType.CUDA);
+ }
+
+ var defaultType = ScalarType.Float16;
+ torch.manual_seed(1);
+ torch.set_default_dtype(defaultType);
+ var weightFolder = @"C:\Users\xiaoyuz\source\repos\Meta-Llama-3.1-8B-Instruct";
+ var configName = "config.json";
+ var originalWeightFolder = Path.Combine(weightFolder, "original");
+
+ Console.WriteLine("Loading Llama from huggingface model weight folder");
+ var stopWatch = System.Diagnostics.Stopwatch.StartNew();
+ stopWatch.Start();
+ var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);
+ var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, layersOnTargetDevice: -1);
+
+ var pipeline = new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, device);
+
+ var agent = new LlamaCausalLMAgent(pipeline, "assistant")
+ .RegisterPrintMessage();
+
+ var task = """
+ Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```.
+ """;
+
+ await agent.SendAsync(task);
+ }
+}
diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj
index 0331a32fc1..d9932106d6 100644
--- a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj
+++ b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj
@@ -9,6 +9,7 @@
+
diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs
index 379fd2b97b..392aec674d 100644
--- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs
+++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs
@@ -26,7 +26,7 @@ public static async Task RunAsync()
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
- var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device);
+ var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device, quantizeToInt8: false);
// agent
var agent = new Phi3Agent(pipeline, "assistant")
diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs
index 5e53ef0ac4..33819a8df4 100644
--- a/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs
+++ b/docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs
@@ -20,7 +20,7 @@ public static ICausalLMPipeline LoadPhi3Mini4KFromFo
string weightFolder,
string configName = "config.json",
string device = "cuda",
- int modelSizeOnCudaInGB = 16,
+ int modelSizeOnCudaInGB = 55,
int modelSizeOnMemoryInGB = 64,
int modelSizeOnDiskInGB = 200,
bool quantizeToInt8 = false,
diff --git a/eng/Versions.props b/eng/Versions.props
index 84b28e1b8f..3b7fe5bd01 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -96,7 +96,7 @@
0.0.13-test
0.0.6-test
0.0.7-test
- 2.0.0-beta.24219.1
+ 2.0.0-beta.24415.1
4.8.6
1.0.118
1.6.24
diff --git a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs
index 18633728a5..a904c394b9 100644
--- a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs
+++ b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs
@@ -197,6 +197,57 @@ public static Dictionary InferDeviceMapForEachLayer(
return deviceMap;
}
+ /// <summary>
+ /// Infer the device map for each layer in the model.
+ /// The device map is a dictionary where the key is the layer name and the value is the device id (e.g. "cuda:0") that the layer will be placed on.
+ /// When inferring the device map, each layer in the model will be placed on the device in the order of the devices list.
+ /// </summary>
+ /// <param name="numberOfLayerToBePlaced">
+ /// a list of key-value pairs where the key is the device id (e.g. "cuda:0") and the value is the number of layers to be placed on the device.
+ /// If you want to place all remaining layers on the device, set that value to -1.
+ /// e.g. [{"cuda:0", 2}, {"cpu", -1}], the first 2 layers will be placed on "cuda:0" and the rest will be placed on "cpu".
+ /// </param>
+ /// <returns>a device map from layer name to device id.</returns>
+ public static Dictionary<string, string> InferDeviceMapForEachLayer(
+ this nn.Module model,
+ IEnumerable<KeyValuePair<string, int>> numberOfLayerToBePlaced)
+ {
+ var layerSizeMap = model.GetSizeForEachDynamicLayerInBytes()
+ .OrderByDescending(x => x.Value)
+ .ToList();
+
+ var deviceMap = new Dictionary<string, string>();
+ foreach (var (device, count) in numberOfLayerToBePlaced)
+ {
+ if (count != -1)
+ {
+ var topK = layerSizeMap.Take(count).ToList();
+ layerSizeMap = layerSizeMap.Skip(count).ToList();
+ foreach (var (key, value) in topK)
+ {
+ deviceMap[key] = device;
+ }
+ }
+ else
+ {
+ foreach (var (key, value) in layerSizeMap)
+ {
+ deviceMap[key] = device;
+ }
+
+ layerSizeMap.Clear();
+ break;
+ }
+ }
+
+ if (layerSizeMap.Count > 0)
+ {
+ throw new ArgumentException("The layer count is not enough to cover all layers, did you forget to set the last layer count to -1?");
+ }
+
+ return deviceMap;
+ }
+
internal static string Peek(this nn.Module model)
{
var sb = new StringBuilder();
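
The overload added above is consumed later in this PR by LlamaForCausalLM.FromPretrained. A minimal sketch of the intended call pattern, assuming `model` is any nn.Module with dynamic-loading layers (device names and layer count are illustrative):

// Put the two largest dynamic layers on the GPU and everything else on the CPU,
// then let ToDynamicLoadingModel (also in Microsoft.ML.GenAI.Core.Extension)
// move layers between devices on demand during forward passes.
var deviceMap = model.InferDeviceMapForEachLayer(
[
    KeyValuePair.Create("cuda:0", 2),
    KeyValuePair.Create("cpu", -1),
]);
model = model.ToDynamicLoadingModel(deviceMap, "cuda:0");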
diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
index dfb64082fb..8745b81c6d 100644
--- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
+++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
@@ -8,16 +8,11 @@
+
+
-
@@ -25,6 +20,8 @@
+
+
diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs b/src/Microsoft.ML.GenAI.Core/Module/Attention.cs
similarity index 60%
rename from src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs
rename to src/Microsoft.ML.GenAI.Core/Module/Attention.cs
index 72c7c8946a..869c213b74 100644
--- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3Attention.cs
+++ b/src/Microsoft.ML.GenAI.Core/Module/Attention.cs
@@ -9,17 +9,19 @@
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Core.Extension;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
-namespace Microsoft.ML.GenAI.Phi.Module;
+namespace Microsoft.ML.GenAI.Core;
-internal class Phi3AttentionInput
+internal class AttentionInput
{
- public Phi3AttentionInput(
+ public AttentionInput(
Tensor hiddenStates,
Tensor positionIds,
+ RotaryEmbeddingOutput positionalEmbeddings, // cos, sin
Tensor? attentionMask = null,
IKVCache? cache = null,
bool outputAttentions = false)
@@ -28,6 +30,7 @@ public Phi3AttentionInput(
this.AttentionMask = attentionMask;
this.PositionIds = positionIds;
this.Cache = cache;
+ this.PositionalEmbeddings = positionalEmbeddings;
this.OutputAttentions = outputAttentions;
}
public Tensor HiddenStates { get; set; }
@@ -36,14 +39,16 @@ public Phi3AttentionInput(
public Tensor PositionIds { get; set; }
+ public RotaryEmbeddingOutput PositionalEmbeddings { get; set; }
+
public IKVCache? Cache { get; set; }
public bool OutputAttentions { get; set; }
}
-internal class Phi3AttentionOutput
+internal class AttentionOutput
{
- public Phi3AttentionOutput(
+ public AttentionOutput(
Tensor hiddenStates,
Tensor? attentions = null,
IKVCache? cache = null)
@@ -60,9 +65,8 @@ public Phi3AttentionOutput(
public IKVCache? Cache { get; set; }
}
-internal class Phi3Attention : nn.Module<Phi3AttentionInput, Phi3AttentionOutput>
+internal class Attention : nn.Module<AttentionInput, AttentionOutput>
{
- private readonly Phi3Config _config;
private readonly int _layerIdx;
private readonly double _attentionDropout;
private readonly int _hiddenSize;
@@ -72,52 +76,57 @@ internal class Phi3Attention : nn.Module<Phi3AttentionInput, Phi3AttentionOutput>
private readonly Dictionary<string, object>? _ropeScaling;
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
private readonly QuantizedLinear o_proj;
- private readonly QuantizedLinear qkv_proj;
- private nn.Module<Phi3RotaryEmbeddingInput, Phi3RotaryEmbeddingOutput> rotary_emb = null!;
+ private readonly QuantizedLinear? qkv_proj;
+ private readonly QuantizedLinear? q_proj;
+ private readonly QuantizedLinear? k_proj;
+ private readonly QuantizedLinear? v_proj;
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
- public Phi3Attention(Phi3Config config, int layerIdx)
- : base(nameof(Phi3Attention))
+ public Attention(
+ double attentionDropout,
+ int hiddenSize,
+ int numHeads,
+ int headDim,
+ int numKeyValueHeads,
+ int numKeyValueGroups,
+ int maxPositionEmbeddings,
+ int originalMaxPositionEmbeddings,
+ int layerIdx,
+ ScalarType dtype,
+ bool attentionBias = false,
+ bool useQkvProj = true)
+ : base(nameof(Attention))
{
- this._config = config;
this._layerIdx = layerIdx;
- this._attentionDropout = config.AttentionDropout;
- this._hiddenSize = config.HiddenSize;
- this._numHeads = config.NumAttentionHeads;
- this._headDim = this._hiddenSize / this._numHeads;
- this._numKeyValueHeads = config.NumKeyValueHeads ?? throw new ArgumentException("num_key_value_heads must be specified");
- this._numKeyValueGroups = this._numHeads / this._numKeyValueHeads;
- this._maxPositionEmbeddings = config.MaxPositionEmbeddings;
- this._originalMaxPositionEmbeddings = config.OriginalMaxPositionEmbeddings;
- this._ropeTheta = config.RopeTheta;
- this._ropeScaling = config.RopeScaling;
+ this._attentionDropout = attentionDropout;
+ this._hiddenSize = hiddenSize;
+ this._numHeads = numHeads;
+ this._headDim = headDim;
+ this._numKeyValueHeads = numKeyValueHeads;
+ this._numKeyValueGroups = numKeyValueGroups;
+ this._maxPositionEmbeddings = maxPositionEmbeddings;
+ this._originalMaxPositionEmbeddings = originalMaxPositionEmbeddings;
Contract.Assert(this._hiddenSize % (this._headDim * this._numHeads) == 0, "hidden_size must be divisible by num_heads");
- var opSize = this._numHeads * this._headDim + 2 * (this._numKeyValueHeads * this._headDim);
- this.o_proj = new QuantizedLinear(this._numHeads * this._headDim, this._hiddenSize, hasBias: false, dtype: config.DType);
- this.qkv_proj = new QuantizedLinear(this._hiddenSize, opSize, hasBias: false, dtype: config.DType);
- this.InitRope();
- }
-
- private void InitRope()
- {
- if (this._ropeScaling is null)
+ this.o_proj = new QuantizedLinear(this._hiddenSize, this._hiddenSize, hasBias: attentionBias, dtype: dtype);
+ if (useQkvProj)
{
- this.rotary_emb = new Phi3RotaryEmbedding(this._ropeTheta, this._maxPositionEmbeddings, this._headDim);
+ var opSize = this._numHeads * this._headDim + 2 * (this._numKeyValueHeads * this._headDim);
+ this.qkv_proj = new QuantizedLinear(this._hiddenSize, opSize, hasBias: attentionBias, dtype: dtype);
}
else
{
- this.rotary_emb = new Phi3SuScaledRotaryEmbedding(this._headDim, this._config);
+ this.q_proj = new QuantizedLinear(this._hiddenSize, this._numHeads * this._headDim, hasBias: attentionBias, dtype: dtype);
+ this.k_proj = new QuantizedLinear(this._hiddenSize, this._numKeyValueHeads * this._headDim, hasBias: attentionBias, dtype: dtype);
+ this.v_proj = new QuantizedLinear(this._hiddenSize, this._numKeyValueHeads * this._headDim, hasBias: attentionBias, dtype: dtype);
}
}
#pragma warning disable MSML_GeneralName // This name should be PascalCased
- public override Phi3AttentionOutput forward(Phi3AttentionInput input)
+ public override AttentionOutput forward(AttentionInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
using (var _ = NewDisposeScope())
@@ -128,26 +137,39 @@ public override Phi3AttentionOutput forward(Phi3AttentionInput input)
var bsz = hiddenStates.shape[0];
var qLen = hiddenStates.shape[1];
- var qkv = this.qkv_proj.forward(hiddenStates);
- var queryPos = this._numHeads * this._headDim;
- var queryStates = qkv[.., .., ..queryPos];
- var keyStates = qkv[.., .., queryPos..(queryPos + this._numKeyValueHeads * this._headDim)];
- var valueStates = qkv[.., .., (queryPos + this._numKeyValueHeads * this._headDim)..];
+ Tensor queryStates;
+ Tensor keyStates;
+ Tensor valueStates;
+
+ if (this.qkv_proj is not null)
+ {
+ var qkv = this.qkv_proj.forward(hiddenStates);
+ var queryPos = this._numHeads * this._headDim;
+ queryStates = qkv[.., .., ..queryPos];
+ keyStates = qkv[.., .., queryPos..(queryPos + this._numKeyValueHeads * this._headDim)];
+ valueStates = qkv[.., .., (queryPos + this._numKeyValueHeads * this._headDim)..];
+ }
+ else if (this.q_proj is not null && this.k_proj is not null && this.v_proj is not null)
+ {
+ queryStates = this.q_proj.forward(hiddenStates);
+ keyStates = this.k_proj.forward(hiddenStates);
+ valueStates = this.v_proj.forward(hiddenStates);
+ }
+ else
+ {
+ throw new InvalidOperationException("Invalid state, either qkv_proj or q_proj, k_proj, v_proj should be initialized");
+ }
+
queryStates = queryStates.view(bsz, qLen, this._numHeads, this._headDim).transpose(1, 2);
keyStates = keyStates.view(bsz, qLen, this._numKeyValueHeads, this._headDim).transpose(1, 2);
valueStates = valueStates.view(bsz, qLen, this._numKeyValueHeads, this._headDim).transpose(1, 2);
-
var kvSeqLen = keyStates.IntShape()[^2];
var pastKeyValue = input.Cache;
if (pastKeyValue is not null)
{
kvSeqLen += pastKeyValue.GetUsableLength(kvSeqLen, this._layerIdx);
}
-
- var embOutput = this.rotary_emb.forward(new Phi3RotaryEmbeddingInput(valueStates, positionIds, kvSeqLen));
- (var cos, var sin) = (embOutput.Cos, embOutput.Sin);
-
- (queryStates, keyStates) = Utils.ApplyRotaryPosEmb(queryStates, keyStates, cos, sin);
+ (queryStates, keyStates) = Utils.ApplyRotaryPosEmb(queryStates, keyStates, input.PositionalEmbeddings.Cos, input.PositionalEmbeddings.Sin);
if (pastKeyValue is not null)
{
@@ -155,9 +177,10 @@ public override Phi3AttentionOutput forward(Phi3AttentionInput input)
}
// repeat k/v heads if n_kv_heads < n_heads
- keyStates = Utils.Phi3RepeatKV(keyStates, this._numKeyValueGroups);
- valueStates = Utils.Phi3RepeatKV(valueStates, this._numKeyValueGroups);
+ keyStates = Utils.RepeatKV(keyStates, this._numKeyValueGroups);
+ valueStates = Utils.RepeatKV(valueStates, this._numKeyValueGroups);
+ // to fp32 to avoid overflow
var attnWeights = torch.matmul(queryStates, keyStates.transpose(2, 3));
attnWeights = attnWeights / Math.Sqrt(this._headDim);
@@ -175,7 +198,7 @@ public override Phi3AttentionOutput forward(Phi3AttentionInput input)
Contract.Assert(attentionMask.shape[0] == bsz);
Contract.Assert(attentionMask.shape[1] == 1);
Contract.Assert(attentionMask.shape[2] == qLen);
- Contract.Assert(attentionMask.shape[3] == kvSeqLen);
+ //Contract.Assert(attentionMask.shape[3] == kvSeqLen);
attnWeights = attnWeights + attentionMask;
}
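
Two things change in this move from Phi3Attention to the shared Attention module: the projection layout becomes configurable (a fused qkv_proj as in Phi-3, or separate q_proj/k_proj/v_proj as Llama checkpoints ship them), and the rotary cos/sin tensors are no longer computed inside each attention layer but are passed in through AttentionInput.PositionalEmbeddings so the model can compute them once per forward pass. A sketch of the Llama-style construction; the sizes are illustrative, the real values come from LlamaConfig later in this PR:

// Separate q/k/v projections, no attention bias, rotary embeddings supplied by the caller.
var llamaStyleAttention = new Attention(
    attentionDropout: 0.0,
    hiddenSize: 4096,
    numHeads: 32,
    headDim: 128,
    numKeyValueHeads: 8,
    numKeyValueGroups: 4,
    maxPositionEmbeddings: 131072,
    originalMaxPositionEmbeddings: 131072,
    layerIdx: 0,
    dtype: ScalarType.BFloat16,
    attentionBias: false,
    useQkvProj: false);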
diff --git a/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs b/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs
index 77bcadeb82..178b8fddda 100644
--- a/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs
+++ b/src/Microsoft.ML.GenAI.Core/Module/GenAILinear.cs
@@ -5,7 +5,7 @@
using TorchSharp;
using static TorchSharp.torch;
-namespace Microsoft.ML.GenAI;
+namespace Microsoft.ML.GenAI.Core;
internal class GenAILinear : nn.Module
{
#pragma warning disable MSML_GeneralName // This name should be PascalCased
diff --git a/src/Microsoft.ML.GenAI.Core/Module/NewGELUActivation.cs b/src/Microsoft.ML.GenAI.Core/Module/NewGELUActivation.cs
index 4c46e53104..a1b523a4df 100644
--- a/src/Microsoft.ML.GenAI.Core/Module/NewGELUActivation.cs
+++ b/src/Microsoft.ML.GenAI.Core/Module/NewGELUActivation.cs
@@ -6,7 +6,7 @@
using TorchSharp;
using static TorchSharp.torch;
-namespace Microsoft.ML.GenAI;
+namespace Microsoft.ML.GenAI.Core;
#pragma warning disable MSML_GeneralName // This name should be PascalCased
internal class NewGELUActivation : torch.nn.Module
#pragma warning disable MSML_GeneralName // This name should be PascalCased
diff --git a/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs b/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs
index 268ac0a4a4..f399efe324 100644
--- a/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs
+++ b/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs
@@ -5,7 +5,7 @@
using Microsoft.ML.GenAI.Core;
using TorchSharp;
using static TorchSharp.torch;
-namespace Microsoft.ML.GenAI;
+namespace Microsoft.ML.GenAI.Core;
internal class QuantizedLinear : GenAILinear, IQuantizeModule
{
@@ -74,6 +74,7 @@ public void Int8()
this.register_buffer("scale", scale);
}
}
+
#pragma warning disable MSML_GeneralName // This name should be PascalCased
public override Tensor forward(Tensor input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3RMSNorm.cs b/src/Microsoft.ML.GenAI.Core/Module/RMSNorm.cs
similarity index 92%
rename from src/Microsoft.ML.GenAI.Phi/Module/Phi3RMSNorm.cs
rename to src/Microsoft.ML.GenAI.Core/Module/RMSNorm.cs
index e8c847268e..b9555cd845 100644
--- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3RMSNorm.cs
+++ b/src/Microsoft.ML.GenAI.Core/Module/RMSNorm.cs
@@ -11,10 +11,10 @@
using TorchSharp.Modules;
using static TorchSharp.torch;
-namespace Microsoft.ML.GenAI.Phi.Module;
+namespace Microsoft.ML.GenAI.Core;
#pragma warning disable MSML_GeneralName // This name should be PascalCased
-internal class Phi3RMSNorm : torch.nn.Module
+internal class RMSNorm : torch.nn.Module
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
private readonly int _dim;
@@ -23,11 +23,11 @@ internal class Phi3RMSNorm : torch.nn.Module
private readonly Parameter weight;
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
- public Phi3RMSNorm(
+ public RMSNorm(
int hiddenSize,
float eps = 1e-6f,
ScalarType dtype = ScalarType.Float32)
- : base(nameof(Phi3RMSNorm))
+ : base(nameof(RMSNorm))
{
this._dim = hiddenSize;
this._eps = eps;
diff --git a/src/Microsoft.ML.GenAI.Core/Module/RotaryEmbedding.cs b/src/Microsoft.ML.GenAI.Core/Module/RotaryEmbedding.cs
new file mode 100644
index 0000000000..8e06c838d5
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Core/Module/RotaryEmbedding.cs
@@ -0,0 +1,125 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Text.Json.Serialization;
+using TorchSharp;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.Core;
+
+public class RopeScalingConfig
+{
+ public RopeScalingConfig()
+ {
+ this.Factor = 1.0f;
+ this.LowFreqFactor = 1.0f;
+ this.HighFreqFactor = 1.0f;
+ this.OriginalMaxPositionEmbeddings = 8192;
+ this.RopeType = "default";
+ }
+
+ [JsonPropertyName("factor")]
+ public float Factor { get; set; }
+
+ [JsonPropertyName("low_freq_factor")]
+ public float LowFreqFactor { get; set; }
+
+ [JsonPropertyName("high_freq_factor")]
+ public float HighFreqFactor { get; set; }
+
+ [JsonPropertyName("original_max_position_embeddings")]
+ public int OriginalMaxPositionEmbeddings { get; set; }
+
+ [JsonPropertyName("rope_type")]
+ public string RopeType { get; set; }
+}
+
+
+internal class RotaryEmbeddingInput
+{
+ public RotaryEmbeddingInput(Tensor input, Tensor positionIds, int? seqLen = null)
+ {
+ Input = input;
+ PositionIds = positionIds;
+ SeqLen = seqLen;
+ }
+
+ public Tensor Input { get; set; }
+
+ public Tensor PositionIds { get; set; }
+
+ public int? SeqLen { get; set; }
+}
+
+internal class RotaryEmbeddingOutput
+{
+ public RotaryEmbeddingOutput(Tensor cos, Tensor sin)
+ {
+ Cos = cos;
+ Sin = sin;
+ }
+
+ public Tensor Cos { get; set; }
+
+ public Tensor Sin { get; set; }
+}
+
+
+internal class RotaryEmbedding : nn.Module<
+ RotaryEmbeddingInput,
+ RotaryEmbeddingOutput>
+{
+ private readonly double _base;
+ private readonly int _maxPositionEmbeddings;
+ private readonly int _dim;
+
+ public RotaryEmbedding(double baseValue, int maxPositionEmbeddings, int dim)
+ : this(baseValue, dim, new RopeScalingConfig() { RopeType = "default", OriginalMaxPositionEmbeddings = maxPositionEmbeddings })
+ {
+ }
+
+ public RotaryEmbedding(double baseValue, int dim, RopeScalingConfig config)
+ : base(nameof(RotaryEmbedding))
+ {
+ _base = baseValue;
+ _maxPositionEmbeddings = config.OriginalMaxPositionEmbeddings;
+ _dim = dim;
+
+ if (config.RopeType == "default")
+ {
+ var thetaNumerator = torch.arange(0, _dim, 2, dtype: ScalarType.Int64).to(torch.float32);
+ this.register_buffer("inv_freq", torch.pow(baseValue, -1.0f * (thetaNumerator / dim)), persistent: false);
+ }
+ else
+ {
+ throw new NotImplementedException("Rope type not implemented");
+ }
+ }
+
+ public int Dim => _dim;
+
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+ public override RotaryEmbeddingOutput forward(RotaryEmbeddingInput input)
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+ {
+ var x = input.Input;
+ var positionIds = input.PositionIds;
+ var seqLen = input.SeqLen;
+ // TODO
+ // can be calculated once and cached
+ var invFreq = this.get_buffer("inv_freq").to(x.device);
+ var invFreqExpanded = invFreq.unsqueeze(0).unsqueeze(-1);
+ invFreqExpanded = invFreqExpanded.expand(new long[] { positionIds.shape[0], -1, 1 });
+ var positionIdsExpanded = positionIds.unsqueeze(1).to(torch.float32);
+ var freqs = invFreqExpanded * positionIdsExpanded;
+ freqs = freqs.transpose(1, 2);
+ var emb = torch.cat([freqs, freqs], dim: -1);
+
+ var cos = torch.cos(emb);
+ var sin = torch.sin(emb);
+
+ return new(cos.to_type(x.dtype), sin.to_type(x.dtype));
+ }
+}
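
RotaryEmbedding currently implements only the "default" rope type; any other rope_type carried by RopeScalingConfig (e.g. Llama 3.1's "llama3" scaling with its low/high frequency factors) throws NotImplementedException in the constructor. forward returns cos/sin tensors shaped [batch, seqLen, dim], cast back to the input dtype, which the decoder then hands to every attention layer. A rough shape check under assumed sizes (dim 64, 16 positions):

// minimal sketch, assuming default rope settings; only dtype/device of "dummy" are consulted
var rope = new RotaryEmbedding(500000.0, maxPositionEmbeddings: 2048, dim: 64);
var positionIds = torch.arange(0L, 16L).unsqueeze(0);   // [1, 16], int64
var dummy = torch.zeros(1, 16, 64);                      // float32 placeholder
var rotary = rope.forward(new RotaryEmbeddingInput(dummy, positionIds));
// rotary.Cos and rotary.Sin: [1, 16, 64], later consumed by Utils.ApplyRotaryPosEmb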
diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelInput.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs
similarity index 96%
rename from src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelInput.cs
rename to src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs
index 49fcfef627..eaf94f2a80 100644
--- a/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelInput.cs
+++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs
@@ -6,7 +6,7 @@
namespace Microsoft.ML.GenAI.Core;
-public class CasualLMModelInput
+public class CausalLMModelInput
{
internal static class Defaults
{
@@ -18,7 +18,7 @@ internal static class Defaults
internal const bool OutputAttentions = false;
internal const bool OutputHiddenStates = false;
}
- public CasualLMModelInput(
+ public CausalLMModelInput(
Tensor inputIds,
Tensor? attentionMask = Defaults.AttentionMask,
Tensor? positionIds = Defaults.PositionIds,
diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelOutput.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs
similarity index 94%
rename from src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelOutput.cs
rename to src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs
index afaa84e778..c10b68e60f 100644
--- a/src/Microsoft.ML.GenAI.Core/Pipeline/CasualLMModelOutput.cs
+++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs
@@ -6,7 +6,7 @@
namespace Microsoft.ML.GenAI.Core;
-public class CasualLMModelOutput
+public class CausalLMModelOutput
{
internal static class Defaults
{
@@ -15,7 +15,7 @@ internal static class Defaults
internal const Tensor[]? Attentions = null;
internal const IKVCache? Cache = null;
}
- public CasualLMModelOutput(
+ public CausalLMModelOutput(
Tensor lastHiddenState,
Tensor? logits = Defaults.Logits,
Tensor[]? allHiddenStates = Defaults.AllHiddenStates,
diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
index 9decdd3207..7ecb64f761 100644
--- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
+++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
@@ -16,7 +16,7 @@ namespace Microsoft.ML.GenAI.Core;
public interface ICausalLMPipeline<TTokenizer, TModel> : ICausalLMPipeline
where TTokenizer : Tokenizer
- where TModel : nn.Module<CasualLMModelInput, CasualLMModelOutput>
+ where TModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
TTokenizer Tokenizer { get; }
@@ -58,7 +58,7 @@ IEnumerable GenerateStreaming(
public class CausalLMPipeline<TTokenizer, TModel> : CausalLMPipeline, ICausalLMPipeline<TTokenizer, TModel>
where TTokenizer : Tokenizer
- where TModel : nn.Module<CasualLMModelInput, CasualLMModelOutput>
+ where TModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
public CausalLMPipeline(
TTokenizer tokenizer,
@@ -86,7 +86,7 @@ internal static class Defaults
public CausalLMPipeline(
Tokenizer tokenizer,
- nn.Module<CasualLMModelInput, CasualLMModelOutput> model,
+ nn.Module<CausalLMModelInput, CausalLMModelOutput> model,
string device = Defaults.Device)
{
this.Tokenizer = tokenizer;
@@ -106,7 +106,7 @@ private protected CausalLMPipeline()
public Tokenizer Tokenizer { get; }
- public nn.Module<CasualLMModelInput, CasualLMModelOutput> Model { get; }
+ public nn.Module<CausalLMModelInput, CausalLMModelOutput> Model { get; }
public Device Device { get; }
@@ -134,7 +134,7 @@ private protected CausalLMPipeline()
var cache = new DynamicKVCache();
if (promptLength == totalLen)
{
- var input = new CasualLMModelInput(inputIds, attentionMask, pastKeyValuesLength: 0)
+ var input = new CausalLMModelInput(inputIds, attentionMask, pastKeyValuesLength: 0)
{
OverrideCache = cache,
};
@@ -143,7 +143,7 @@ private protected CausalLMPipeline()
}
for (var curPos = promptLength; curPos != totalLen; curPos++)
{
- var input = new CasualLMModelInput(inputIds[.., prevPos..curPos], attentionMask[.., prevPos..curPos], pastKeyValuesLength: prevPos)
+ var input = new CausalLMModelInput(inputIds[.., prevPos..curPos], attentionMask[.., prevPos..curPos], pastKeyValuesLength: prevPos)
{
OverrideCache = cache,
};
diff --git a/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs
new file mode 100644
index 0000000000..a0720694c3
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs
@@ -0,0 +1,27 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using AutoGen.Core;
+using Microsoft.SemanticKernel.ChatCompletion;
+
+namespace Microsoft.ML.GenAI.Core;
+
+public interface ISemanticKernelChatTemplateBuilder
+{
+ string BuildPrompt(ChatHistory chatHistory);
+}
+
+public interface IAutoGenChatTemplateBuilder
+{
+ string BuildPrompt(IEnumerable<IMessage> messages);
+}
+
+public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder
+{
+}
diff --git a/src/Microsoft.ML.GenAI.Core/Utils.cs b/src/Microsoft.ML.GenAI.Core/Utils.cs
index 2f46e7d43d..e4e1078d2e 100644
--- a/src/Microsoft.ML.GenAI.Core/Utils.cs
+++ b/src/Microsoft.ML.GenAI.Core/Utils.cs
@@ -145,7 +145,7 @@ public static Tensor Phi2RepeatKV(Tensor x, int nRep)
.view(batchSize, seqLen, nKVHeads * nRep, headDim);
}
- public static Tensor Phi3RepeatKV(Tensor x, int nRep)
+ public static Tensor RepeatKV(Tensor x, int nRep)
{
var batchSize = x.shape[0];
var nKVHeads = x.shape[1];
@@ -156,9 +156,9 @@ public static Tensor Phi3RepeatKV(Tensor x, int nRep)
return x;
}
- return x.unsqueeze(3)
+ return x.unsqueeze(2)
.expand(batchSize, nKVHeads, nRep, seqLen, headDim)
- .view(batchSize, nKVHeads * nRep, seqLen, headDim);
+ .reshape(batchSize, nKVHeads * nRep, seqLen, headDim);
}
}
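
The RepeatKV change is more than a rename: the singleton axis now goes in at position 2 so that it lines up with the (batch, kvHeads, nRep, seqLen, headDim) expand, and reshape replaces view because expand produces a non-contiguous tensor. A shape walk-through under assumed sizes (batch 2, 8 KV heads repeated 4 times, 16 positions, head dim 64):

var x = torch.zeros(2, 8, 16, 64);       // (batch, kvHeads, seqLen, headDim)
var y = x.unsqueeze(2)                   // (2, 8, 1, 16, 64)
    .expand(2, 8, 4, 16, 64)             // (2, 8, 4, 16, 64), non-contiguous
    .reshape(2, 8 * 4, 16, 64);          // (2, 32, 16, 64), the layout Utils.RepeatKV returns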
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs
new file mode 100644
index 0000000000..b96dee6dba
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs
@@ -0,0 +1,90 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Text;
+using AutoGen.Core;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+
+namespace Microsoft.ML.GenAI.LLaMA;
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+{
+ private const char Newline = '\n';
+
+ public string BuildPrompt(IEnumerable<IMessage> messages)
+ {
+ var availableRoles = new[] { Role.System, Role.User, Role.Assistant };
+ if (messages.Any(m => m.GetContent() is null))
+ {
+ throw new InvalidOperationException("Please provide a message with content.");
+ }
+
+ if (messages.Any(m => m.GetRole() is null || availableRoles.Contains(m.GetRole()!.Value) == false))
+ {
+ throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
+ }
+
+ // construct template based on instruction from
+ // https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py#L280
+
+ var sb = new StringBuilder();
+ sb.Append("<|begin_of_text|>");
+ foreach (var message in messages)
+ {
+ var role = message.GetRole()!.Value;
+ var content = message.GetContent()!;
+ sb.Append(message switch
+ {
+ _ when message.GetRole() == Role.System => $"<|start_header_id|>system<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
+ _ when message.GetRole() == Role.User => $"<|start_header_id|>user<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
+ _ when message.GetRole() == Role.Assistant => $"<|start_header_id|>assistant<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
+ _ => throw new InvalidOperationException("Invalid role.")
+ });
+ }
+
+ sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}");
+ var input = sb.ToString();
+
+ return input;
+ }
+
+ public string BuildPrompt(ChatHistory chatHistory)
+ {
+ // build prompt from chat history
+ var sb = new StringBuilder();
+
+ sb.Append("<|begin_of_text|>");
+ foreach (var message in chatHistory)
+ {
+ foreach (var item in message.Items)
+ {
+ if (item is not TextContent textContent)
+ {
+ throw new NotSupportedException($"Only text content is supported, but got {item.GetType().Name}");
+ }
+
+ var text = textContent.Text?.Trim() ?? string.Empty;
+
+ var prompt = message.Role switch
+ {
+ _ when message.Role == AuthorRole.System => $"<|start_header_id|>system<|end_header_id|>{Newline}{text}<|eot_id|>{Newline}",
+ _ when message.Role == AuthorRole.User => $"<|start_header_id|>user<|end_header_id|>{Newline}{text}<|eot_id|>{Newline}",
+ _ when message.Role == AuthorRole.Assistant => $"<|start_header_id|>assistant<|end_header_id|>{Newline}{text}<|eot_id|>{Newline}",
+ _ => throw new NotSupportedException($"Unsupported role {message.Role}")
+ };
+
+ sb.Append(prompt);
+ }
+ }
+
+ sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}");
+
+ return sb.ToString();
+ }
+
+ public static Llama3_1ChatTemplateBuilder Instance { get; } = new Llama3_1ChatTemplateBuilder();
+}
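
For reference, this is what the AutoGen-facing BuildPrompt renders for a one-turn conversation; the message contents are illustrative and the expected string follows directly from the template above:

var prompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(new IMessage[]
{
    new TextMessage(Role.System, "you are a helpful assistant"),
    new TextMessage(Role.User, "Hello"),
});
// prompt:
// <|begin_of_text|><|start_header_id|>system<|end_header_id|>
// you are a helpful assistant<|eot_id|>
// <|start_header_id|>user<|end_header_id|>
// Hello<|eot_id|>
// <|start_header_id|>assistant<|end_header_id|>
// (the trailing newline after the final header leaves the model to generate the assistant turn)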
diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs
new file mode 100644
index 0000000000..5deabd6df2
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs
@@ -0,0 +1,89 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.CompilerServices;
+using AutoGen.Core;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.Tokenizers;
+
+namespace Microsoft.ML.GenAI.LLaMA;
+
+public class LlamaCausalLMAgent : IStreamingAgent
+{
+ private const char Newline = '\n';
+ private readonly ICausalLMPipeline _pipeline;
+ private readonly string? _systemMessage;
+ private readonly IAutoGenChatTemplateBuilder _templateBuilder;
+
+ /// <summary>
+ /// Create a new instance of <see cref="LlamaCausalLMAgent"/>.
+ /// </summary>
+ /// <param name="pipeline">pipeline</param>
+ /// <param name="name">agent name</param>
+ /// <param name="systemMessage">system message.</param>
+ /// <param name="templateBuilder">the template builder to build chat prompt. If the value is null, <see cref="Llama3_1ChatTemplateBuilder.Instance"/> would be used.</param>
+ public LlamaCausalLMAgent(
+ ICausalLMPipeline pipeline,
+ string name,
+ string? systemMessage = "you are a helpful assistant",
+ IAutoGenChatTemplateBuilder? templateBuilder = null)
+ {
+ this.Name = name;
+ this._pipeline = pipeline;
+ this._systemMessage = systemMessage;
+ this._templateBuilder = templateBuilder ?? Llama3_1ChatTemplateBuilder.Instance;
+ }
+
+ public string Name { get; }
+
+ public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ if (_systemMessage != null)
+ {
+ var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
+ messages = messages.Prepend(systemMessage);
+ }
+ var input = _templateBuilder.BuildPrompt(messages);
+ var maxLen = options?.MaxToken ?? 1024;
+ var temperature = options?.Temperature ?? 0.7f;
+ var stopTokenSequence = options?.StopSequence ?? [];
+ stopTokenSequence = stopTokenSequence.Append("<|eot_id|>").ToArray();
+
+ var output = _pipeline.Generate(
+ input,
+ maxLen: maxLen,
+ temperature: temperature,
+ stopSequences: stopTokenSequence) ?? throw new InvalidOperationException("Failed to generate a reply.");
+
+ return Task.FromResult<IMessage>(new TextMessage(Role.Assistant, output, from: this.Name));
+ }
+
+#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
+ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
+#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
+ IEnumerable<IMessage> messages,
+ GenerateReplyOptions? options = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ if (_systemMessage != null)
+ {
+ var systemMessage = new TextMessage(Role.System, _systemMessage, from: this.Name);
+ messages = messages.Prepend(systemMessage);
+ }
+ var input = _templateBuilder.BuildPrompt(messages);
+ var maxLen = options?.MaxToken ?? 1024;
+ var temperature = options?.Temperature ?? 0.7f;
+ var stopTokenSequence = options?.StopSequence ?? [];
+ stopTokenSequence = stopTokenSequence.Append("<|eot_id|>").ToArray();
+
+ foreach (var output in _pipeline.GenerateStreaming(
+ input,
+ maxLen: maxLen,
+ temperature: temperature,
+ stopSequences: stopTokenSequence))
+ {
+ yield return new TextMessageUpdate(Role.Assistant, output, from: this.Name);
+ }
+ }
+}
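
A minimal streaming sketch for the agent, assuming `pipeline` was built as in the LLaMA3_1 sample at the top of this PR; TextMessageUpdate is the AutoGen update type yielded above:

var agent = new LlamaCausalLMAgent(pipeline, "assistant");
await foreach (var update in agent.GenerateStreamingReplyAsync([new TextMessage(Role.User, "Hello")]))
{
    if (update is TextMessageUpdate textUpdate)
    {
        Console.Write(textUpdate.Content);
    }
}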
diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaChatCompletionService.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaChatCompletionService.cs
new file mode 100644
index 0000000000..3e43e7eefb
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaChatCompletionService.cs
@@ -0,0 +1,55 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.CompilerServices;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.Tokenizers;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+
+namespace Microsoft.ML.GenAI.LLaMA;
+
+public class LlamaChatCompletionService : IChatCompletionService
+{
+ private readonly ICausalLMPipeline _pipeline;
+ private readonly LlamaTextCompletionService _textGenerationService;
+ private readonly ISemanticKernelChatTemplateBuilder _templateBuilder;
+
+ /// <summary>
+ /// Create a new instance of <see cref="LlamaChatCompletionService"/>.
+ /// </summary>
+ /// <param name="pipeline">pipeline</param>
+ /// <param name="templateBuilder">The template builder to use for generating chat prompts; if not provided, <see cref="Llama3_1ChatTemplateBuilder.Instance"/> will be used.</param>
+ public LlamaChatCompletionService(ICausalLMPipeline pipeline, ISemanticKernelChatTemplateBuilder? templateBuilder = null)
+ {
+ _pipeline = pipeline;
+ _textGenerationService = new LlamaTextCompletionService(pipeline);
+ _templateBuilder = templateBuilder ?? Llama3_1ChatTemplateBuilder.Instance;
+ }
+
+ public IReadOnlyDictionary<string, object?> Attributes => _textGenerationService.Attributes;
+
+ public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
+ {
+ var prompt = _templateBuilder.BuildPrompt(chatHistory);
+ var replies = await _textGenerationService.GetTextContentsAsync(prompt, executionSettings, kernel, cancellationToken);
+
+ return replies.Select(reply => new ChatMessageContent(AuthorRole.Assistant, reply.Text)).ToList();
+ }
+
+ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(
+ ChatHistory chatHistory,
+ PromptExecutionSettings? executionSettings = null,
+ Kernel? kernel = null,
+ [EnumeratorCancellation]
+ CancellationToken cancellationToken = default)
+ {
+ var prompt = _templateBuilder.BuildPrompt(chatHistory);
+
+ await foreach (var reply in _textGenerationService.GetStreamingTextContentsAsync(prompt, executionSettings, kernel, cancellationToken))
+ {
+ yield return new StreamingChatMessageContent(AuthorRole.Assistant, reply.Text);
+ }
+ }
+}
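
The Semantic Kernel path mirrors the AutoGen one; a minimal sketch, again assuming a `pipeline` created as in the LLaMA3_1 sample:

var chatService = new LlamaChatCompletionService(pipeline);
var history = new ChatHistory();
history.AddSystemMessage("you are a helpful assistant");
history.AddUserMessage("Write a haiku about the sea.");
var replies = await chatService.GetChatMessageContentsAsync(history);
Console.WriteLine(replies[0].Content);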
diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaConfig.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaConfig.cs
new file mode 100644
index 0000000000..a8a6985ee8
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaConfig.cs
@@ -0,0 +1,124 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using System.Threading.Tasks;
+using Microsoft.ML.GenAI.Core;
+using TorchSharp;
+
+namespace Microsoft.ML.GenAI.LLaMA;
+
+public class LlamaConfig
+{
+ public LlamaConfig()
+ {
+ this.AttentionBias = false;
+ this.AttentionDropout = 0.0;
+ this.HiddenAct = "silu";
+ this.HiddenSize = 4096;
+ this.InitializerRange = 0.02;
+ this.IntermediateSize = 14336;
+ this.MaxPositionEmbeddings = 131072;
+ this.MlpBias = false;
+ this.NumAttentionHeads = 32;
+ this.NumHiddenLayers = 32;
+ this.NumKeyValueHeads = 8;
+ this.PretrainingTp = 1;
+ this.RmsNormEps = 1e-05f;
+ this.RopeScaling = new RopeScalingConfig();
+ this.RopeTheta = 500000.0;
+ this.TieWordEmbeddings = false;
+ this.VocabSize = 128256;
+ this.AttnImplementation = "eager";
+ this.DType = torch.ScalarType.BFloat16;
+ }
+
+ static LlamaConfig()
+ {
+#pragma warning disable MSML_ParameterLocalVarName // Parameter or local variable name not standard
+ var llama3_1_8b_content = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.LLaMA.Resource.Config.meta-llama-3.1-8B-Instruct.json");
+ var llama3_1_70b_content = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.LLaMA.Resource.Config.meta-llama-3.1-70B-Instruct.json");
+ var llama3_1_405b_content = Utils.GetEmbeddedResource("Microsoft.ML.GenAI.LLaMA.Resource.Config.meta-llama-3.1-405B-Instruct.json");
+#pragma warning restore MSML_ParameterLocalVarName // Parameter or local variable name not standard
+
+ Llama3_1_8B_Instruct = JsonSerializer.Deserialize<LlamaConfig>(llama3_1_8b_content) ?? throw new ArgumentNullException(nameof(llama3_1_8b_content));
+ Llama3_1_70B_Instruct = JsonSerializer.Deserialize<LlamaConfig>(llama3_1_70b_content) ?? throw new ArgumentNullException(nameof(llama3_1_70b_content));
+ Llama3_1_405B_Instruct = JsonSerializer.Deserialize<LlamaConfig>(llama3_1_405b_content) ?? throw new ArgumentNullException(nameof(llama3_1_405b_content));
+ }
+
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+ /// <summary>
+ /// The llama-3.1-8B-Instruct configuration created from https://huggingface.co/meta-llama/Meta-Llama-3.1-8B.
+ /// </summary>
+ public static LlamaConfig Llama3_1_8B_Instruct { get; }
+
+ /// <summary>
+ /// The llama-3.1-70B-Instruct configuration created from https://huggingface.co/meta-llama/Meta-Llama-3.1-70B.
+ /// </summary>
+ public static LlamaConfig Llama3_1_70B_Instruct { get; }
+
+ /// <summary>
+ /// The llama-3.1-405B-Instruct configuration created from https://huggingface.co/meta-llama/Meta-Llama-3.1-405B.
+ /// </summary>
+ public static LlamaConfig Llama3_1_405B_Instruct { get; }
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+
+ [JsonPropertyName("attention_bias")]
+ public bool AttentionBias { get; set; }
+
+ [JsonPropertyName("attention_dropout")]
+ public double AttentionDropout { get; set; }
+
+ [JsonPropertyName("hidden_act")]
+ public string HiddenAct { get; set; }
+
+ [JsonPropertyName("hidden_size")]
+ public int HiddenSize { get; set; }
+
+ [JsonPropertyName("initializer_range")]
+ public double InitializerRange { get; set; }
+
+ [JsonPropertyName("intermediate_size")]
+ public int IntermediateSize { get; set; }
+
+ [JsonPropertyName("max_position_embeddings")]
+ public int MaxPositionEmbeddings { get; set; }
+
+ [JsonPropertyName("mlp_bias")]
+ public bool MlpBias { get; set; }
+
+ [JsonPropertyName("num_attention_heads")]
+ public int NumAttentionHeads { get; set; }
+
+ [JsonPropertyName("num_hidden_layers")]
+ public int NumHiddenLayers { get; set; }
+
+ [JsonPropertyName("num_key_value_heads")]
+ public int NumKeyValueHeads { get; set; }
+
+ [JsonPropertyName("pretraining_tp")]
+ public int PretrainingTp { get; set; }
+
+ [JsonPropertyName("rms_norm_eps")]
+ public float RmsNormEps { get; set; }
+
+ public RopeScalingConfig RopeScaling { get; set; }
+
+ [JsonPropertyName("rope_theta")]
+ public double RopeTheta { get; set; }
+
+ [JsonPropertyName("tie_word_embeddings")]
+ public bool TieWordEmbeddings { get; set; }
+
+ [JsonPropertyName("vocab_size")]
+ public int VocabSize { get; set; }
+ public int? PadTokenId { get; set; }
+ public torch.ScalarType DType { get; set; }
+ public string AttnImplementation { get; set; }
+}
diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs
new file mode 100644
index 0000000000..b7e038da1b
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs
@@ -0,0 +1,121 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Diagnostics;
+using System.Text.Json;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Core.Extension;
+using Microsoft.ML.GenAI.LLaMA.Module;
+using TorchSharp;
+using TorchSharp.PyBridge;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.LLaMA;
+
+public class LlamaForCausalLM : nn.Module<CausalLMModelInput, CausalLMModelOutput>
+{
+ private readonly LlamaConfig _config;
+ private readonly int _vocabSize;
+
+#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
+ private readonly GenAILinear lm_head;
+ private readonly LlamaModel model;
+#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
+
+ public LlamaForCausalLM(LlamaConfig config, string? device = null)
+ : base(nameof(LlamaForCausalLM))
+ {
+ _config = config;
+ _vocabSize = config.VocabSize;
+
+ model = new LlamaModel(config, device);
+ lm_head = new GenAILinear(config.HiddenSize, config.VocabSize, hasBias: false);
+
+ this.RegisterComponents();
+ }
+
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+ public override CausalLMModelOutput forward(CausalLMModelInput input)
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+ {
+ var outputs = this.model.forward(input);
+ var logits = this.lm_head.forward(outputs.LastHiddenState);
+ logits = logits.to_type(ScalarType.Float32);
+ outputs.Logits = logits;
+
+ return outputs;
+ }
+
+ public static LlamaForCausalLM FromPretrained(
+ string modelFolder,
+ string configName = "config.json",
+ string checkPointName = "model.safetensors.index.json",
+ ScalarType torchDtype = ScalarType.BFloat16,
+ string device = "cpu")
+ {
+ var config = Path.Join(modelFolder, configName);
+ var modelConfig = JsonSerializer.Deserialize<LlamaConfig>(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config));
+ modelConfig.DType = torchDtype;
+ var model = new LlamaForCausalLM(modelConfig);
+
+ model.LoadSafeTensors(modelFolder, checkPointName);
+ model = model.to(device);
+
+ return model;
+ }
+
+ public static LlamaForCausalLM FromPretrained(
+ string modelFolder,
+ string configName = "config.json",
+ string checkPointName = "model.safetensors.index.json",
+ bool quantizeToInt8 = false,
+ bool quantizeToInt4 = false,
+ int layersOnTargetDevice = -1,
+ ScalarType torchDtype = ScalarType.BFloat16,
+ string targetDevice = "cuda")
+ {
+ if (layersOnTargetDevice == -1 && quantizeToInt4 == false && quantizeToInt8 == false)
+ {
+ return FromPretrained(modelFolder, configName, checkPointName, torchDtype, targetDevice);
+ }
+
+ var originalDefaultDevice = torch.get_default_device();
+ torch.set_default_device("meta");
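+ // The "meta" pass below builds a shadow copy of the model that allocates no real
+ // memory; it only exists so that (optionally quantized) per-layer sizes can be
+ // measured and a device map inferred before the real weights are loaded on "cpu".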
+ var config = Path.Join(modelFolder, configName);
+ var modelConfig = JsonSerializer.Deserialize<LlamaConfig>(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config));
+ modelConfig.DType = torchDtype;
+ var model = new LlamaForCausalLM(modelConfig);
+
+ if (quantizeToInt8)
+ {
+ model.ToInt8QuantizeModule();
+ }
+ else if (quantizeToInt4)
+ {
+ model.ToInt4QuantizeModule();
+ }
+
+ var deviceMap = model.InferDeviceMapForEachLayer(
+ [
+ KeyValuePair.Create(targetDevice, layersOnTargetDevice),
+ KeyValuePair.Create("cpu", -1)
+ ]);
+
+ torch.set_default_device("cpu");
+ model = new LlamaForCausalLM(modelConfig);
+
+ model.LoadSafeTensors(modelFolder, checkPointName);
+
+ model = model.ToDynamicLoadingModel(deviceMap, targetDevice);
+
+ torch.set_default_device(originalDefaultDevice);
+
+ return model;
+ }
+
+ public void LoadSafeTensors(string modelFolder, string checkPointName = "model.safetensors.index.json")
+ {
+ this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: true, useTqdm: false);
+ }
+}
diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaTextCompletionService.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaTextCompletionService.cs
new file mode 100644
index 0000000000..5ac0a9afb9
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaTextCompletionService.cs
@@ -0,0 +1,77 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.Tokenizers;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.TextGeneration;
+
+namespace Microsoft.ML.GenAI.LLaMA;
+
+public class LlamaTextCompletionService : ITextGenerationService
+{
+ private readonly ICausalLMPipeline _pipeline;
+
+ public LlamaTextCompletionService(ICausalLMPipeline pipeline)
+ {
+ _pipeline = pipeline;
+ }
+
+ public IReadOnlyDictionary<string, object?> Attributes => new Dictionary<string, object?>()
+ {
+ { "temperature", null },
+ { "max_token", null },
+ { "stop_token_sequence", null },
+ { "top_p", null },
+ };
+
+#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
+ public async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsync(
+#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
+ string prompt,
+ PromptExecutionSettings? executionSettings = null,
+ Kernel? kernel = null,
+ [EnumeratorCancellation]
+ CancellationToken cancellationToken = default)
+ {
+ var temperature = executionSettings?.ExtensionData?["temperature"] as float? ?? 0.7f;
+ var maxToken = executionSettings?.ExtensionData?["max_token"] as int? ?? 100;
+ var stopTokenSequence = executionSettings?.ExtensionData?["stop_token_sequence"] as string[] ?? Array.Empty<string>();
+ var topP = executionSettings?.ExtensionData?["top_p"] as float? ?? 0.9f;
+ stopTokenSequence = stopTokenSequence.Append("<|eot_id|>").ToArray();
+
+ foreach (var item in _pipeline.GenerateStreaming(
+ prompt,
+ maxToken,
+ temperature,
+ topP,
+ stopTokenSequence))
+ {
+ yield return new StreamingTextContent(item);
+ }
+ }
+
+ public Task<IReadOnlyList<TextContent>> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
+ {
+ var temperature = executionSettings?.ExtensionData?["temperature"] as float? ?? 0.7f;
+ var maxToken = executionSettings?.ExtensionData?["max_token"] as int? ?? 512;
+ var stopTokenSequence = executionSettings?.ExtensionData?["stop_token_sequence"] as List<string> ?? new List<string>();
+ var topP = executionSettings?.ExtensionData?["top_p"] as float? ?? 0.9f;
+ stopTokenSequence.Add("<|eot_id|>");
+ var response = _pipeline.Generate(
+ prompt,
+ maxToken,
+ temperature,
+ stopSequences: stopTokenSequence.ToArray(),
+ topP: topP);
+
+ return Task.FromResult<IReadOnlyList<TextContent>>([new TextContent(response)]);
+ }
+}
diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaTokenizerHelper.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaTokenizerHelper.cs
new file mode 100644
index 0000000000..ea6f49edf7
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaTokenizerHelper.cs
@@ -0,0 +1,55 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Text.RegularExpressions;
+using System.Threading.Tasks;
+using Microsoft.ML.Tokenizers;
+
+namespace Microsoft.ML.GenAI.LLaMA;
+
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+public class LlamaTokenizerHelper
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+{
+ /// <summary>
+ /// https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct/blob/main/tokenizer.json#pre_tokenizer.pretokenizers.pattern
+ /// </summary>
+ private const string _re = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+
+ /// <summary>
+ /// https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct/blob/main/tokenizer.json#added_tokens
+ /// </summary>
+ private static readonly Dictionary<string, int> _specialTokens = new()
+ {
+ { "<|begin_of_text|>", 128000 },
+ { "<|end_of_text|>", 128001 },
+ { "<|finetune_right_pad_id|>", 128004 },
+ { "<|start_header_id|>", 128006 },
+ { "<|end_header_id|>", 128007 },
+ { "<|eom_id|>", 128008 },
+ { "<|eot_id|>", 128009 },
+ { "<|system|>", 32006 },
+ { "<|user|>", 32010 },
+ { "<|assistant|>", 32001 },
+ { "<|end|>", 32007 }
+ };
+
+ /// <summary>
+ /// Create a <see cref="TiktokenTokenizer"/> from tokenizer model file.
+ /// </summary>
+ /// <param name="modelWeightFolder">path to tokenizer model folder</param>
+ /// <param name="modelFile">tokenizer model file name</param>
+ public static TiktokenTokenizer FromPretrained(
+ string modelWeightFolder,
+ string modelFile = "tokenizer.model")
+ {
+ var modelFilePath = Path.Join(modelWeightFolder, modelFile);
+ var preTokenizer = new TiktokenPreTokenizer(new Regex(_re), _specialTokens);
+ return TiktokenTokenizer.Create(File.OpenRead(modelFilePath), preTokenizer, normalizer: null, specialTokens: _specialTokens);
+ }
+}
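
A minimal usage sketch; the folder layout follows the LLaMA3_1 sample above, where tokenizer.model sits in the checkpoint's "original" subfolder, and `weightFolder` is an assumed local path:

var tokenizer = LlamaTokenizerHelper.FromPretrained(Path.Combine(weightFolder, "original"));
var ids = tokenizer.EncodeToIds("<|start_header_id|>user<|end_header_id|>\nHello<|eot_id|>");
var text = tokenizer.Decode(ids);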
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Microsoft.ML.GenAI.LLaMA.csproj b/src/Microsoft.ML.GenAI.LLaMA/Microsoft.ML.GenAI.LLaMA.csproj
new file mode 100644
index 0000000000..5b0cb0acc0
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/Microsoft.ML.GenAI.LLaMA.csproj
@@ -0,0 +1,24 @@
+
+
+
+ net6.0;net8.0
+ enable
+ enable
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaDecoderLayer.cs b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaDecoderLayer.cs
new file mode 100644
index 0000000000..0e3132f739
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaDecoderLayer.cs
@@ -0,0 +1,154 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Core.Extension;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.LLaMA.Module;
+
+internal class DecoderLayerInput
+{
+ public DecoderLayerInput(
+ Tensor hiddenStates,
+ Tensor attentionMask,
+ Tensor positionIds,
+ RotaryEmbeddingOutput positionEmbeddings, // cos, sin
+ IKVCache? pastKeyValue = null,
+ bool outputAttentions = false)
+ {
+ this.HiddenStates = hiddenStates;
+ this.AttentionMask = attentionMask;
+ this.PositionIds = positionIds;
+ this.PastKeyValue = pastKeyValue;
+ this.OutputAttentions = outputAttentions;
+ this.PositionalEmbeddings = positionEmbeddings;
+ }
+
+ public Tensor HiddenStates { get; set; }
+
+ public Tensor AttentionMask { get; set; }
+
+ public Tensor PositionIds { get; set; }
+
+ public RotaryEmbeddingOutput PositionalEmbeddings { get; set; }
+
+ public IKVCache? PastKeyValue { get; set; }
+
+ public bool OutputAttentions { get; set; }
+}
+
+internal class DecoderLayerOutput
+{
+ public DecoderLayerOutput(
+ Tensor hiddenStates,
+ Tensor? attentions = null,
+ IKVCache? pastKeyValue = null)
+ {
+ this.HiddenStates = hiddenStates;
+ this.Attentions = attentions;
+ this.PastKeyValue = pastKeyValue;
+ }
+
+ public Tensor HiddenStates { get; set; }
+
+ public Tensor? Attentions { get; set; }
+
+ public IKVCache? PastKeyValue { get; set; }
+}
+internal class LlamaDecoderLayer : nn.Module<DecoderLayerInput, DecoderLayerOutput>, IDynamicLoadModule
+{
+ private readonly LlamaConfig _llamaConfig;
+ private readonly int _layerIndex;
+ private readonly int _hiddenSize;
+
+#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
+ private readonly LlamaMLP mlp;
+ private readonly Core.RMSNorm input_layernorm;
+ private readonly Core.RMSNorm post_attention_layernorm;
+ private readonly Attention self_attn;
+
+ public Action<nn.Module>? LoadToDeviceFunc { get; set; }
+ public Action<nn.Module>? UnloadFromDeviceFunc { get; set; }
+
+#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
+
+ public LlamaDecoderLayer(LlamaConfig config, int layerIndex)
+ : base(nameof(LlamaDecoderLayer))
+ {
+ _llamaConfig = config;
+ _layerIndex = layerIndex;
+ _hiddenSize = config.HiddenSize;
+
+ this.self_attn = CreateAttention(config, layerIndex);
+ this.mlp = new LlamaMLP(config);
+ this.input_layernorm = new Core.RMSNorm(this._hiddenSize, eps: config.RmsNormEps, config.DType);
+ this.post_attention_layernorm = new Core.RMSNorm(this._hiddenSize, eps: config.RmsNormEps, config.DType);
+ }
+
+ private Attention CreateAttention(LlamaConfig config, int layerIndex)
+ {
+ var headDim = config.HiddenSize / config.NumAttentionHeads;
+ return new Attention(
+ attentionDropout: config.AttentionDropout,
+ hiddenSize: config.HiddenSize,
+ numHeads: config.NumAttentionHeads,
+ headDim: headDim,
+ numKeyValueHeads: config.NumKeyValueHeads,
+ numKeyValueGroups: config.NumAttentionHeads / config.NumKeyValueHeads,
+ maxPositionEmbeddings: config.MaxPositionEmbeddings,
+ originalMaxPositionEmbeddings: config.MaxPositionEmbeddings,
+ layerIdx: layerIndex,
+ useQkvProj: false,
+ dtype: config.DType,
+ attentionBias: config.AttentionBias);
+ }
+
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+ public override DecoderLayerOutput forward(DecoderLayerInput input)
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+ {
+ if (LoadToDeviceFunc != null)
+ {
+ LoadToDeviceFunc(this);
+ }
+
+ using var disposeScope = NewDisposeScope();
+ var residual = input.HiddenStates;
+ var hiddenStates = this.input_layernorm.forward(input.HiddenStates);
+
+ var selfAttnInput = new AttentionInput(
+ hiddenStates: hiddenStates,
+ attentionMask: input.AttentionMask,
+ positionIds: input.PositionIds,
+ cache: input.PastKeyValue,
+ positionalEmbeddings: input.PositionalEmbeddings,
+ outputAttentions: input.OutputAttentions);
+
+ var selfAttnOutput = this.self_attn.forward(selfAttnInput);
+
+ hiddenStates = residual + selfAttnOutput.HiddenStates;
+
+ // Fully connected
+ residual = hiddenStates;
+ hiddenStates = this.post_attention_layernorm.forward(hiddenStates);
+ hiddenStates = this.mlp.forward(hiddenStates);
+ hiddenStates = residual + hiddenStates;
+
+ if (UnloadFromDeviceFunc != null)
+ {
+ UnloadFromDeviceFunc(this);
+ }
+
+ return new DecoderLayerOutput(
+ hiddenStates: hiddenStates.MoveToOuterDisposeScope(),
+ attentions: input.OutputAttentions ? selfAttnOutput.Attentions?.MoveToOuterDisposeScope() : null,
+ pastKeyValue: selfAttnOutput.Cache);
+ }
+}
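
The forward pass above follows the standard pre-norm residual layout: normalize, attend, add the residual, then normalize, run the MLP, and add the residual again. A minimal sketch of that data flow in isolation, with the attention, MLP and RMSNorm calls stubbed out as plain delegates rather than the real modules:

```csharp
using System;
using static TorchSharp.torch;

// Sketch of the pre-norm residual layout used by LlamaDecoderLayer.forward;
// norm1/attn/norm2/mlp stand in for input_layernorm, self_attn, post_attention_layernorm and mlp.
static Tensor DecoderBlock(
    Tensor hidden,
    Func<Tensor, Tensor> norm1, Func<Tensor, Tensor> attn,
    Func<Tensor, Tensor> norm2, Func<Tensor, Tensor> mlp)
{
    var residual = hidden;
    hidden = residual + attn(norm1(hidden));   // attention sub-block
    residual = hidden;
    hidden = residual + mlp(norm2(hidden));    // feed-forward sub-block
    return hidden;
}

// Exercise the flow with identity stubs; the real layer plugs in the modules above.
var y = DecoderBlock(zeros(1, 4, 16), h => h, h => h, h => h, h => h);
Console.WriteLine(string.Join(", ", y.shape));
```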
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaMLP.cs b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaMLP.cs
new file mode 100644
index 0000000000..cbc841f144
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaMLP.cs
@@ -0,0 +1,61 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.LLaMA;
+using TorchSharp;
+using TorchSharp.Modules;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.LLaMA.Module;
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+internal class LlamaMLP : torch.nn.Module<Tensor, Tensor>
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+{
+ private readonly int _pretrainingTp;
+ private readonly int _intermediateSize;
+ private readonly int _hiddenSize;
+ private readonly bool _hasBias;
+#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
+ private readonly QuantizedLinear gate_proj;
+ private readonly QuantizedLinear up_proj;
+ private readonly QuantizedLinear down_proj;
+ private readonly torch.nn.Module<Tensor, Tensor> activation_fn;
+#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
+
+ public LlamaMLP(LlamaConfig config)
+ : base(nameof(LlamaMLP))
+ {
+ this._hiddenSize = config.HiddenSize;
+ this._intermediateSize = config.IntermediateSize;
+ this._hasBias = config.MlpBias;
+ this._pretrainingTp = config.PretrainingTp;
+ var hiddenAct = config.HiddenAct;
+ this.gate_proj = new QuantizedLinear(this._hiddenSize, this._intermediateSize, hasBias: this._hasBias, dtype: config.DType);
+ this.up_proj = new QuantizedLinear(this._hiddenSize, this._intermediateSize, hasBias: this._hasBias, dtype: config.DType);
+ this.down_proj = new QuantizedLinear(this._intermediateSize, this._hiddenSize, hasBias: this._hasBias, dtype: config.DType);
+ this.RegisterComponents();
+ this.activation_fn = Core.Utils.GetActivation(hiddenAct);
+ }
+
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+ public override Tensor forward(Tensor input)
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+ {
+ if (this._pretrainingTp > 1)
+ {
+ throw new NotImplementedException("PretrainingTp > 1 is not supported yet.");
+ }
+
+ using var input1 = this.gate_proj.forward(input);
+ using var input2 = this.activation_fn.forward(input1);
+ using var input3 = input2 * this.up_proj.forward(input);
+ return this.down_proj.forward(input3);
+ }
+}
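
The forward pass computes the usual Llama gated MLP, `down_proj(act(gate_proj(x)) * up_proj(x))`. A standalone sketch of the same computation with plain (non-quantized) linear layers and silu, using the 8B sizes as assumed dimensions:

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

// Plain-linear sketch of the gated MLP above; sizes taken from the 8B config (hidden 4096, intermediate 14336).
long hidden = 4096, intermediate = 14336;
var gateProj = nn.Linear(hidden, intermediate, hasBias: false);
var upProj = nn.Linear(hidden, intermediate, hasBias: false);
var downProj = nn.Linear(intermediate, hidden, hasBias: false);

var x = randn(1, 8, hidden);                                                    // [batch, seq, hidden]
var y = downProj.forward(nn.functional.silu(gateProj.forward(x)) * upProj.forward(x));
Console.WriteLine(string.Join(", ", y.shape));                                  // 1, 8, 4096
```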
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs
new file mode 100644
index 0000000000..1ba7820a9f
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs
@@ -0,0 +1,154 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Core.Extension;
+using TorchSharp;
+using TorchSharp.Modules;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.LLaMA.Module;
+
+internal class LlamaModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
+{
+ private readonly LlamaConfig _config;
+ private readonly int? _paddingIdx;
+ private readonly int _vocabSize;
+ private IKVCache _cache;
+#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
+ private readonly Embedding embed_tokens;
+ private readonly ModuleList<LlamaDecoderLayer> layers;
+ private readonly RMSNorm norm;
+#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
+ private readonly nn.Module<RotaryEmbeddingInput, RotaryEmbeddingOutput> _rotaryEmb;
+
+
+ public LlamaModel(LlamaConfig config, string? device = null)
+ : base(nameof(LlamaModel))
+ {
+ this._config = config;
+ this._paddingIdx = config.PadTokenId;
+ this._vocabSize = config.VocabSize;
+ var headDim = config.HiddenSize / config.NumAttentionHeads;
+ this.embed_tokens = nn.Embedding(config.VocabSize, config.HiddenSize, padding_idx: this._paddingIdx, dtype: config.DType, device: device);
+ this.layers = new ModuleList<LlamaDecoderLayer>();
+
+ for (int i = 0; i < config.NumHiddenLayers; i++)
+ {
+ this.layers.Add(new LlamaDecoderLayer(config, i));
+ }
+ this.norm = new RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType);
+ this._cache = new DynamicKVCache();
+ this.RegisterComponents();
+ this._rotaryEmb = config.RopeScaling switch
+ {
+ null => new RotaryEmbedding(config.RopeTheta, config.MaxPositionEmbeddings, headDim),
+ _ => new RotaryEmbedding(config.RopeTheta, headDim, config.RopeScaling),
+ };
+ }
+
+#pragma warning disable MSML_GeneralName // This name should be PascalCased
+ public override CausalLMModelOutput forward(CausalLMModelInput input)
+#pragma warning restore MSML_GeneralName // This name should be PascalCased
+ {
+ if (input.OverrideCache is not null)
+ {
+ this._cache = input.OverrideCache;
+ }
+
+ var outputAttentions = input.OutputAttentions;
+ var outputHiddenStates = input.OutputHiddenStates;
+ var attentionMask = input.AttentionMask;
+ Device device;
+ var inputIds = input.InputIds;
+ var positionIds = input.PositionIds;
+ var inputsEmbeds = input.InputEmbeddings;
+ int batchSize;
+ int seqLength;
+ if (inputIds is not null && inputsEmbeds is not null)
+ {
+ throw new ArgumentException("Only one of input_ids or inputs_embeds may be set");
+ }
+ else if (inputIds is not null)
+ {
+ batchSize = inputIds.IntShape()[0];
+ seqLength = inputIds.IntShape()[1];
+ inputsEmbeds = this.embed_tokens.forward(inputIds);
+ device = inputIds.device;
+ }
+ else if (inputsEmbeds is not null)
+ {
+ batchSize = inputsEmbeds.IntShape()[0];
+ seqLength = inputsEmbeds.IntShape()[1];
+ device = inputsEmbeds.device;
+ }
+ else
+ {
+ throw new ArgumentException("Either input_ids or inputs_embeds must be set");
+ }
+
+ var pastKeyValuesLength = input.PastKeyValuesLength;
+
+ if (positionIds is null)
+ {
+ positionIds = torch.arange(pastKeyValuesLength, seqLength + pastKeyValuesLength, device: device);
+ positionIds = positionIds.unsqueeze(0).view(-1, seqLength);
+ }
+ else
+ {
+ positionIds = positionIds.view(-1, seqLength).to_type(ScalarType.Int64);
+ }
+
+ if (this._config.AttnImplementation == "flash_attention_2")
+ {
+ throw new NotImplementedException();
+ }
+ else
+ {
+ // the following behavior of creating 4d causal mask doesn't match python's, remember to look into it when there's time.
+ attentionMask = AttentionMaskConverter.Create4DCausalAttentionMask(attentionMask, [batchSize, seqLength], inputsEmbeds.dtype, device, pastKeyValuesLength);
+ }
+
+ var hiddenStates = inputsEmbeds;
+
+ var allHiddenStates = new List<Tensor>();
+ var allAttentions = new List<Tensor>();
+
+ var embOutput = this._rotaryEmb.forward(new RotaryEmbeddingInput(hiddenStates, positionIds, pastKeyValuesLength));
+ foreach (var layer in this.layers)
+ {
+ if (outputHiddenStates)
+ {
+ allHiddenStates.Add(hiddenStates);
+ }
+
+ var decoderInput = new DecoderLayerInput(
+ hiddenStates: hiddenStates,
+ attentionMask: attentionMask!,
+ positionIds: positionIds,
+ pastKeyValue: this._cache,
+ positionEmbeddings: embOutput,
+ outputAttentions: outputAttentions);
+ var layerOutput = layer.forward(decoderInput);
+ hiddenStates = layerOutput.HiddenStates;
+ if (outputAttentions && layerOutput.Attentions is not null)
+ {
+ allAttentions.Add(layerOutput.Attentions);
+ }
+ }
+
+ hiddenStates = this.norm.forward(hiddenStates);
+ if (outputHiddenStates)
+ {
+ allHiddenStates.Add(hiddenStates);
+ }
+
+ return new CausalLMModelOutput(lastHiddenState: hiddenStates, allHiddenStates: allHiddenStates.ToArray(), attentions: allAttentions.ToArray(), cache: this._cache);
+ }
+}
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Resource/Config/meta-llama-3.1-405B-Instruct.json b/src/Microsoft.ML.GenAI.LLaMA/Resource/Config/meta-llama-3.1-405B-Instruct.json
new file mode 100644
index 0000000000..373b94f4f6
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/Resource/Config/meta-llama-3.1-405B-Instruct.json
@@ -0,0 +1,32 @@
+{
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 128000,
+ "eos_token_id": [
+ 128001,
+ 128008,
+ 128009
+ ],
+ "hidden_act": "silu",
+ "hidden_size": 16384,
+ "initializer_range": 0.02,
+ "intermediate_size": 53248,
+ "max_position_embeddings": 131072,
+ "mlp_bias": false,
+ "num_attention_heads": 128,
+ "num_hidden_layers": 126,
+ "num_key_value_heads": 8,
+ "pretraining_tp": 1,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": {
+ "factor": 8.0,
+ "high_freq_factor": 4.0,
+ "low_freq_factor": 1.0,
+ "original_max_position_embeddings": 8192,
+ "rope_type": "llama3"
+ },
+ "rope_theta": 500000.0,
+ "tie_word_embeddings": false,
+ "use_cache": true,
+ "vocab_size": 128256
+}
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Resource/Config/meta-llama-3.1-70B-Instruct.json b/src/Microsoft.ML.GenAI.LLaMA/Resource/Config/meta-llama-3.1-70B-Instruct.json
new file mode 100644
index 0000000000..2cd3ad59ac
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/Resource/Config/meta-llama-3.1-70B-Instruct.json
@@ -0,0 +1,32 @@
+{
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 128000,
+ "eos_token_id": [
+ 128001,
+ 128008,
+ 128009
+ ],
+ "hidden_act": "silu",
+ "hidden_size": 8192,
+ "initializer_range": 0.02,
+ "intermediate_size": 28672,
+ "max_position_embeddings": 131072,
+ "mlp_bias": false,
+ "num_attention_heads": 64,
+ "num_hidden_layers": 80,
+ "num_key_value_heads": 8,
+ "pretraining_tp": 1,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": {
+ "factor": 8.0,
+ "low_freq_factor": 1.0,
+ "high_freq_factor": 4.0,
+ "original_max_position_embeddings": 8192,
+ "rope_type": "llama3"
+ },
+ "rope_theta": 500000.0,
+ "tie_word_embeddings": false,
+ "use_cache": true,
+ "vocab_size": 128256
+}
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Resource/Config/meta-llama-3.1-8B-Instruct.json b/src/Microsoft.ML.GenAI.LLaMA/Resource/Config/meta-llama-3.1-8B-Instruct.json
new file mode 100644
index 0000000000..750f5671d6
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/Resource/Config/meta-llama-3.1-8B-Instruct.json
@@ -0,0 +1,33 @@
+{
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 128000,
+ "eos_token_id": [
+ 128001,
+ 128008,
+ 128009
+ ],
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 14336,
+ "max_position_embeddings": 131072,
+ "mlp_bias": false,
+ "model_type": "llama",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "num_key_value_heads": 8,
+ "pretraining_tp": 1,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": {
+ "factor": 8.0,
+ "low_freq_factor": 1.0,
+ "high_freq_factor": 4.0,
+ "original_max_position_embeddings": 8192,
+ "rope_type": "llama3"
+ },
+ "rope_theta": 500000.0,
+ "tie_word_embeddings": false,
+ "use_cache": true,
+ "vocab_size": 128256
+}
diff --git a/src/Microsoft.ML.GenAI.LLaMA/Utils.cs b/src/Microsoft.ML.GenAI.LLaMA/Utils.cs
new file mode 100644
index 0000000000..622aba9fff
--- /dev/null
+++ b/src/Microsoft.ML.GenAI.LLaMA/Utils.cs
@@ -0,0 +1,27 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Reflection;
+using TorchSharp;
+using static TorchSharp.torch;
+
+namespace Microsoft.ML.GenAI.LLaMA;
+
+internal static class Utils
+{
+ public static string GetEmbeddedResource(string resourceName)
+ {
+ // read file content from embedded resource
+ var assembly = Assembly.GetExecutingAssembly();
+ var resourceStream = assembly.GetManifestResourceStream(resourceName);
+
+ if (resourceStream == null)
+ {
+ throw new ArgumentException($"Resource '{resourceName}' not found.", nameof(resourceName));
+ }
+
+ using var reader = new System.IO.StreamReader(resourceStream);
+ return reader.ReadToEnd();
+ }
+}
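
A quick sketch of how this helper would typically be used to load one of the bundled config files. The manifest resource name shown is an assumption (the actual name depends on how the csproj declares the `EmbeddedResource` items), and since `Utils` is internal the call has to come from inside the Microsoft.ML.GenAI.LLaMA assembly or a friend assembly.

```csharp
using System;
using Microsoft.ML.GenAI.LLaMA;

// Hypothetical manifest resource name for the bundled 8B config; verify against
// Assembly.GetManifestResourceNames() if the lookup throws.
var json = Utils.GetEmbeddedResource(
    "Microsoft.ML.GenAI.LLaMA.Resource.Config.meta-llama-3.1-8B-Instruct.json");

Console.WriteLine(json.Substring(0, Math.Min(200, json.Length)));
```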
diff --git a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj
index a9556443dd..e8605ba403 100644
--- a/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj
+++ b/src/Microsoft.ML.GenAI.Phi/Microsoft.ML.GenAI.Phi.csproj
@@ -7,19 +7,10 @@
-
-
-
-
-
diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi2Attention.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi2Attention.cs
index 918ae7c99b..fe0021980f 100644
--- a/src/Microsoft.ML.GenAI.Phi/Module/Phi2Attention.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi2Attention.cs
@@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.
using System.Diagnostics.Contracts;
+using Microsoft.ML.GenAI.Core;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs
index 384d012e22..42bd892588 100644
--- a/src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi2MLP.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using Microsoft.ML.GenAI.Core;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3DecoderLayer.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3DecoderLayer.cs
index 399cd25646..35b9313b33 100644
--- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3DecoderLayer.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3DecoderLayer.cs
@@ -8,6 +8,7 @@
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Core.Extension;
using TorchSharp.Modules;
using static TorchSharp.torch;
@@ -19,6 +20,7 @@ public Phi3DecoderLayerInput(
Tensor hiddenStates,
Tensor attentionMask,
Tensor positionIds,
+ RotaryEmbeddingOutput positionalEmbeddings, // cos, sin
IKVCache? pastKeyValue = null,
bool outputAttentions = false)
{
@@ -26,6 +28,7 @@ public Phi3DecoderLayerInput(
this.AttentionMask = attentionMask;
this.PositionIds = positionIds;
this.PastKeyValue = pastKeyValue;
+ this.PositionalEmbeddings = positionalEmbeddings;
this.OutputAttentions = outputAttentions;
}
@@ -35,6 +38,8 @@ public Phi3DecoderLayerInput(
public Tensor PositionIds { get; set; }
+ public RotaryEmbeddingOutput PositionalEmbeddings { get; set; } // cos, sin
+
public IKVCache? PastKeyValue { get; set; }
public bool OutputAttentions { get; set; }
@@ -63,12 +68,12 @@ internal class Phi3DecoderLayer : nn.Module
- private readonly nn.Module<Phi3AttentionInput, Phi3AttentionOutput> self_attn;
+ private readonly nn.Module<AttentionInput, AttentionOutput> self_attn;
private readonly Phi3MLP mlp;
- private readonly Phi3RMSNorm input_layernorm;
+ private readonly RMSNorm input_layernorm;
private readonly Dropout resid_attn_dropout;
private readonly Dropout resid_mlp_dropout;
- private readonly Phi3RMSNorm post_attention_layernorm;
+ private readonly RMSNorm post_attention_layernorm;
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
public Phi3DecoderLayer(Phi3Config config, int layerIdx)
@@ -77,7 +82,7 @@ public Phi3DecoderLayer(Phi3Config config, int layerIdx)
this._config = config;
if (config.AttnImplementation == "eager")
{
- this.self_attn = new Phi3Attention(config, layerIdx);
+ this.self_attn = this.CreateAttentionFromConfig(config, layerIdx);
}
else
{
@@ -85,11 +90,11 @@ public Phi3DecoderLayer(Phi3Config config, int layerIdx)
}
this.mlp = new Phi3MLP(config);
- this.input_layernorm = new Phi3RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType);
+ this.input_layernorm = new RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType);
this.resid_attn_dropout = nn.Dropout(config.ResidPdrop);
this.resid_mlp_dropout = nn.Dropout(config.ResidPdrop);
- this.post_attention_layernorm = new Phi3RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType);
+ this.post_attention_layernorm = new RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType);
}
public Action<nn.Module>? LoadToDeviceFunc { get; set; }
@@ -109,7 +114,13 @@ public override Phi3DecoderLayerOutput forward(Phi3DecoderLayerInput input)
var residual = input.HiddenStates;
hiddenStates = this.input_layernorm.forward(hiddenStates);
- var attentionInput = new Phi3AttentionInput(hiddenStates, input.PositionIds, input.AttentionMask, input.PastKeyValue, input.OutputAttentions);
+ var attentionInput = new AttentionInput(
+ hiddenStates: hiddenStates,
+ positionIds: input.PositionIds,
+ attentionMask: input.AttentionMask,
+ cache: input.PastKeyValue,
+ positionalEmbeddings: input.PositionalEmbeddings,
+ outputAttentions: input.OutputAttentions);
var output = this.self_attn.forward(attentionInput);
var attnOutputs = output.HiddenStates;
var selfAttnWeights = output.Attentions;
@@ -126,4 +137,21 @@ public override Phi3DecoderLayerOutput forward(Phi3DecoderLayerInput input)
}
return new Phi3DecoderLayerOutput(hiddenStates.MoveToOuterDisposeScope(), selfAttnWeights?.MoveToOuterDisposeScope(), presentKeyValue);
}
+
+ private Attention CreateAttentionFromConfig(Phi3Config config, int layerIdx)
+ {
+ var headDim = config.HiddenSize / config.NumAttentionHeads;
+ return new Attention(
+ attentionDropout: config.AttentionDropout,
+ hiddenSize: config.HiddenSize,
+ numHeads: config.NumAttentionHeads,
+ headDim: headDim,
+ numKeyValueHeads: config.NumKeyValueHeads ?? throw new ArgumentException("num_key_value_heads must be specified"),
+ numKeyValueGroups: config.NumAttentionHeads / config.NumKeyValueHeads ?? throw new ArgumentException("num_key_value_heads must be specified"),
+ maxPositionEmbeddings: config.MaxPositionEmbeddings,
+ originalMaxPositionEmbeddings: config.OriginalMaxPositionEmbeddings,
+ layerIdx: layerIdx,
+ useQkvProj: true,
+ dtype: config.DType);
+ }
}
diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs
index 745c000800..65c0413e39 100644
--- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3MLP.cs
@@ -7,6 +7,7 @@
using System.Linq;
using System.Text;
using System.Threading.Tasks;
+using Microsoft.ML.GenAI.Core;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
@@ -33,7 +34,7 @@ public Phi3MLP(int hiddenSize, int intermediateSize, string hiddenAct, ScalarTyp
this.gate_up_proj = new QuantizedLinear(hiddenSize, 2 * intermediateSize, hasBias: false, dtype: dtype);
this.down_proj = new QuantizedLinear(intermediateSize, hiddenSize, hasBias: false, dtype: dtype);
this.RegisterComponents();
- this.activation_fn = Utils.GetActivation(hiddenAct);
+ this.activation_fn = Core.Utils.GetActivation(hiddenAct);
}
#pragma warning disable MSML_GeneralName // This name should be PascalCased
diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3Model.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3Model.cs
index 9f9f0a17ab..e873ddd9d8 100644
--- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3Model.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3Model.cs
@@ -5,11 +5,12 @@
using Microsoft.ML.GenAI.Core;
using TorchSharp;
using TorchSharp.Modules;
+using Microsoft.ML.GenAI.Core.Extension;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Phi.Module;
-internal class Phi3Model : nn.Module<CasualLMModelInput, CasualLMModelOutput>
+internal class Phi3Model : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
private readonly Phi3Config _config;
private readonly int _paddingIdx;
@@ -19,8 +20,9 @@ internal class Phi3Model : nn.Module
private readonly Embedding embed_tokens;
private readonly Dropout embed_dropout;
private readonly ModuleList<Phi3DecoderLayer> layers;
- private readonly Phi3RMSNorm norm;
+ private readonly RMSNorm norm;
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
+ private readonly nn.Module<RotaryEmbeddingInput, RotaryEmbeddingOutput> _rotaryEmb;
public Phi3Model(Phi3Config config)
: base(nameof(Phi3Model))
@@ -28,6 +30,7 @@ public Phi3Model(Phi3Config config)
this._config = config;
this._paddingIdx = config.PadTokenId ?? 32000;
this._vocabSize = config.VocabSize;
+ var headDim = config.HiddenSize / config.NumAttentionHeads;
this.embed_tokens = nn.Embedding(config.VocabSize, config.HiddenSize, padding_idx: this._paddingIdx, dtype: config.DType);
this.embed_dropout = nn.Dropout(config.EmbdPdrop);
@@ -37,12 +40,18 @@ public Phi3Model(Phi3Config config)
{
this.layers.Add(new Phi3DecoderLayer(config, i));
}
- this.norm = new Phi3RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType);
+ this.norm = new RMSNorm(config.HiddenSize, config.RmsNormEps, config.DType);
this._cache = new DynamicKVCache();
this.RegisterComponents();
+
+ this._rotaryEmb = config.RopeScaling switch
+ {
+ null => new RotaryEmbedding(config.RopeTheta, config.MaxPositionEmbeddings, headDim),
+ _ => new Phi3SuScaledRotaryEmbedding(headDim, config),
+ };
}
#pragma warning disable MSML_GeneralName // This name should be PascalCased
- public override CasualLMModelOutput forward(CasualLMModelInput input)
+ public override CausalLMModelOutput forward(CausalLMModelInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
if (input.OverrideCache is not null)
@@ -103,18 +112,22 @@ public override CasualLMModelOutput forward(CasualLMModelInput input)
}
var hiddenStates = inputsEmbeds;
-
+ var positionEmbeddings = this._rotaryEmb.forward(new RotaryEmbeddingInput(hiddenStates, positionIds, seqLength));
var allHiddenStates = new List<Tensor>();
var allAttentions = new List<Tensor>();
-
foreach (var layer in this.layers)
{
if (outputHiddenStates)
{
allHiddenStates.Add(hiddenStates);
}
-
- var decoderInput = new Phi3DecoderLayerInput(hiddenStates, attentionMask!, positionIds, this._cache, outputAttentions);
+ var decoderInput = new Phi3DecoderLayerInput(
+ hiddenStates: hiddenStates,
+ attentionMask: attentionMask!,
+ positionIds: positionIds,
+ pastKeyValue: this._cache,
+ positionalEmbeddings: positionEmbeddings,
+ outputAttentions: outputAttentions);
var layerOutput = layer.forward(decoderInput);
hiddenStates = layerOutput.HiddenStates;
if (outputAttentions && layerOutput.Attentions is not null)
@@ -129,6 +142,6 @@ public override CasualLMModelOutput forward(CasualLMModelInput input)
allHiddenStates.Add(hiddenStates);
}
- return new CasualLMModelOutput(lastHiddenState: hiddenStates, allHiddenStates: allHiddenStates.ToArray(), attentions: allAttentions.ToArray(), cache: this._cache);
+ return new CausalLMModelOutput(lastHiddenState: hiddenStates, allHiddenStates: allHiddenStates.ToArray(), attentions: allAttentions.ToArray(), cache: this._cache);
}
}
diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3RotaryEmbedding.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3RotaryEmbedding.cs
deleted file mode 100644
index 9b04a301d6..0000000000
--- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3RotaryEmbedding.cs
+++ /dev/null
@@ -1,81 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using TorchSharp;
-using static TorchSharp.torch;
-
-namespace Microsoft.ML.GenAI.Phi.Module;
-internal class Phi3RotaryEmbeddingInput
-{
- public Phi3RotaryEmbeddingInput(Tensor input, Tensor positionIds, int? seqLen = null)
- {
- Input = input;
- PositionIds = positionIds;
- SeqLen = seqLen;
- }
-
- public Tensor Input { get; set; }
-
- public Tensor PositionIds { get; set; }
-
- public int? SeqLen { get; set; }
-}
-
-internal class Phi3RotaryEmbeddingOutput
-{
- public Phi3RotaryEmbeddingOutput(Tensor cos, Tensor sin)
- {
- Cos = cos;
- Sin = sin;
- }
-
- public Tensor Cos { get; set; }
-
- public Tensor Sin { get; set; }
-}
-
-
-internal class Phi3RotaryEmbedding : nn.Module<
- Phi3RotaryEmbeddingInput,
- Phi3RotaryEmbeddingOutput>
-{
- private readonly double _base;
- private readonly int _maxPositionEmbeddings;
- private readonly int _dim;
-
- public Phi3RotaryEmbedding(double baseValue, int maxPositionEmbeddings, int dim)
- : base(nameof(Phi3RotaryEmbedding))
- {
- _base = baseValue;
- _maxPositionEmbeddings = maxPositionEmbeddings;
- _dim = dim;
- var thetaNumerator = torch.arange(0, _dim, 2, dtype: ScalarType.Int64).to(torch.float32);
- this.register_buffer("inv_freq", torch.pow(baseValue, -1.0f * (thetaNumerator / dim)), persistent: false);
- }
-
- public int Dim => _dim;
-
-#pragma warning disable MSML_GeneralName // This name should be PascalCased
- public override Phi3RotaryEmbeddingOutput forward(Phi3RotaryEmbeddingInput input)
-#pragma warning restore MSML_GeneralName // This name should be PascalCased
- {
- var x = input.Input;
- var positionIds = input.PositionIds;
- var seqLen = input.SeqLen;
- // TODO
- // can be calculated once and cached
- var invFreq = this.get_buffer("inv_freq").to(x.device);
- var invFreqExpanded = invFreq.unsqueeze(0).unsqueeze(-1);
- invFreqExpanded = invFreqExpanded.expand(new long[] { positionIds.shape[0], -1, 1 });
- var positionIdsExpanded = positionIds.unsqueeze(1).to(torch.float32);
- var freqs = invFreqExpanded * positionIdsExpanded;
- freqs = freqs.transpose(1, 2);
- var emb = torch.cat([freqs, freqs], dim: -1);
-
- var cos = torch.cos(emb);
- var sin = torch.sin(emb);
-
- return new(cos.to_type(x.dtype), sin.to_type(x.dtype));
- }
-}
diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi3SuScaledRotaryEmbedding.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi3SuScaledRotaryEmbedding.cs
index ce0e70b686..e2170493e4 100644
--- a/src/Microsoft.ML.GenAI.Phi/Module/Phi3SuScaledRotaryEmbedding.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi3SuScaledRotaryEmbedding.cs
@@ -8,12 +8,13 @@
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
+using Microsoft.ML.GenAI.Core;
using TorchSharp;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Phi.Module;
-internal class Phi3SuScaledRotaryEmbedding : Phi3RotaryEmbedding
+internal class Phi3SuScaledRotaryEmbedding : RotaryEmbedding
{
private readonly double[] _shortFactor;
private readonly double[] _longFactor;
@@ -35,7 +36,7 @@ public Phi3SuScaledRotaryEmbedding(int dim, Phi3Config config)
}
#pragma warning disable MSML_GeneralName // This name should be PascalCased
- public override Phi3RotaryEmbeddingOutput forward(Phi3RotaryEmbeddingInput input)
+ public override RotaryEmbeddingOutput forward(RotaryEmbeddingInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
var seqLen = (torch.max(input.PositionIds) + 1).ToInt32();
diff --git a/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs
index efb3f23de9..1d49375565 100644
--- a/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs
@@ -14,7 +14,7 @@
namespace Microsoft.ML.GenAI.Phi;
-public class Phi2ForCasualLM : nn.Module<CasualLMModelInput, CasualLMModelOutput>
+public class Phi2ForCasualLM : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format
private readonly Phi2Model model;
@@ -30,7 +30,7 @@ public Phi2ForCasualLM(Phi2Config config)
}
#pragma warning disable MSML_GeneralName // This name should be PascalCased
- public override CasualLMModelOutput forward(CasualLMModelInput input) // use_cache, output_attentions, output_hidden_states
+ public override CausalLMModelOutput forward(CausalLMModelInput input) // use_cache, output_attentions, output_hidden_states
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
var inputIds = input.InputIds;
@@ -44,7 +44,7 @@ public override CasualLMModelOutput forward(CasualLMModelInput input) // use_cac
var lmLogits = this.lm_head.forward(hiddenState);
- return new CasualLMModelOutput(lastHiddenState: hiddenState, logits: lmLogits);
+ return new CausalLMModelOutput(lastHiddenState: hiddenState, logits: lmLogits);
}
public static Phi2ForCasualLM FromPretrained(
diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs
index efe3089fdb..480e0d7e04 100644
--- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatCompletionService.cs
@@ -33,8 +33,8 @@ public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync
CancellationToken cancellationToken = default)
{
var prompt = BuildPrompt(chatHistory);
- var reply = await _textGenerationService.GetTextContentAsync(prompt, executionSettings, kernel, cancellationToken);
- return [new ChatMessageContent(AuthorRole.Assistant, reply.Text)];
+ var replies = await _textGenerationService.GetTextContentsAsync(prompt, executionSettings, kernel, cancellationToken);
+ return replies.Select(reply => new ChatMessageContent(AuthorRole.Assistant, reply.Text)).ToList();
}
public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(
diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs
index 41b2d970fd..c67741377e 100644
--- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs
@@ -17,7 +17,7 @@
namespace Microsoft.ML.GenAI.Phi;
-public class Phi3ForCasualLM : nn.Module<CasualLMModelInput, CasualLMModelOutput>
+public class Phi3ForCasualLM : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
private readonly Phi3Config _config;
@@ -37,7 +37,7 @@ public Phi3ForCasualLM(Phi3Config config)
}
#pragma warning disable MSML_GeneralName // This name should be PascalCased
- public override CasualLMModelOutput forward(CasualLMModelInput input)
+ public override CausalLMModelOutput forward(CausalLMModelInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
var outputs = this.model.forward(input);
diff --git a/src/Microsoft.ML.GenAI.Phi/README.md b/src/Microsoft.ML.GenAI.Phi/README.md
index 758a78ad47..2daf51039e 100644
--- a/src/Microsoft.ML.GenAI.Phi/README.md
+++ b/src/Microsoft.ML.GenAI.Phi/README.md
@@ -6,10 +6,10 @@ The following phi-models are supported and tested:
- [x] [Phi-2](https://huggingface.co/microsoft/phi-2)
- [x] [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [x] [Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct)
+- [x] [Phi-3-medium-4k-instruct](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct)
+- [x] [Phi-3-medium-128k-instruct](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct)
- [ ] [Phi-3-small-8k-instruct](https://huggingface.co/microsoft/Phi-3-small-8k-instruct)
- [ ] [Phi-3-small-128k-instruct](https://huggingface.co/microsoft/Phi-3-small-128k-instruct)
-- [ ] [Phi-3-medium-4k-instruct](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct)
-- [ ] [Phi-3-medium-128k-instruct](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct)
- [ ] [Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-large-4k-instruct)
## Getting Started with Semantic Kernel
diff --git a/src/Microsoft.ML.GenAI.Phi/Utils.cs b/src/Microsoft.ML.GenAI.Phi/Utils.cs
index 4591d94f14..aa5a71719e 100644
--- a/src/Microsoft.ML.GenAI.Phi/Utils.cs
+++ b/src/Microsoft.ML.GenAI.Phi/Utils.cs
@@ -130,18 +130,6 @@ public static (Tensor, Tensor) ApplyRotaryPosEmb(Tensor q, Tensor k, Tensor cos,
return (qEmbed, kEmbed);
}
- public static Module<Tensor, Tensor> GetActivation(string actFn)
- {
- return actFn switch
- {
- "silu" => nn.SiLU(),
- "relu" => nn.ReLU(),
- "gelu" => nn.GELU(),
- "tanh" => nn.Tanh(),
- "swish" => nn.SiLU(),
- _ => throw new ArgumentException("Invalid activation function", actFn),
- };
- }
public static Tensor Phi2RepeatKV(Tensor x, int nRep)
diff --git a/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt
new file mode 100644
index 0000000000..e4a2466fec
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.ItBuildChatTemplateFromAutoGenChatHistory.approved.txt
@@ -0,0 +1,7 @@
+<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+You are a helpful AI assistant.<|eot_id|>
+<|start_header_id|>user<|end_header_id|>
+Hello?<|eot_id|>
+<|start_header_id|>assistant<|end_header_id|>
+World!<|eot_id|>
+<|start_header_id|>assistant<|end_header_id|>
diff --git a/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.ItBuildChatTemplateFromSemanticKernelChatHistory.approved.txt b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.ItBuildChatTemplateFromSemanticKernelChatHistory.approved.txt
new file mode 100644
index 0000000000..e4a2466fec
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.ItBuildChatTemplateFromSemanticKernelChatHistory.approved.txt
@@ -0,0 +1,7 @@
+<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+You are a helpful AI assistant.<|eot_id|>
+<|start_header_id|>user<|end_header_id|>
+Hello?<|eot_id|>
+<|start_header_id|>assistant<|end_header_id|>
+World!<|eot_id|>
+<|start_header_id|>assistant<|end_header_id|>
diff --git a/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.Llama_3_1_405b_ShapeTest.approved.txt b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.Llama_3_1_405b_ShapeTest.approved.txt
new file mode 100644
index 0000000000..6b8d7749dc
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.Llama_3_1_405b_ShapeTest.approved.txt
@@ -0,0 +1,1137 @@
+0: lm_head.weight shape: [128256, 16384]
+1: model.embed_tokens.weight shape: [128256, 16384]
+2: model.layers.0.input_layernorm.weight shape: [16384]
+3: model.layers.0.mlp.down_proj.weight shape: [16384, 53248]
+4: model.layers.0.mlp.gate_proj.weight shape: [53248, 16384]
+5: model.layers.0.mlp.up_proj.weight shape: [53248, 16384]
+6: model.layers.0.post_attention_layernorm.weight shape: [16384]
+7: model.layers.0.self_attn.k_proj.weight shape: [1024, 16384]
+8: model.layers.0.self_attn.o_proj.weight shape: [16384, 16384]
+9: model.layers.0.self_attn.q_proj.weight shape: [16384, 16384]
+10: model.layers.0.self_attn.v_proj.weight shape: [1024, 16384]
+11: model.layers.1.input_layernorm.weight shape: [16384]
+12: model.layers.1.mlp.down_proj.weight shape: [16384, 53248]
+13: model.layers.1.mlp.gate_proj.weight shape: [53248, 16384]
+14: model.layers.1.mlp.up_proj.weight shape: [53248, 16384]
+15: model.layers.1.post_attention_layernorm.weight shape: [16384]
+16: model.layers.1.self_attn.k_proj.weight shape: [1024, 16384]
+17: model.layers.1.self_attn.o_proj.weight shape: [16384, 16384]
+18: model.layers.1.self_attn.q_proj.weight shape: [16384, 16384]
+19: model.layers.1.self_attn.v_proj.weight shape: [1024, 16384]
+20: model.layers.10.input_layernorm.weight shape: [16384]
+21: model.layers.10.mlp.down_proj.weight shape: [16384, 53248]
+22: model.layers.10.mlp.gate_proj.weight shape: [53248, 16384]
+23: model.layers.10.mlp.up_proj.weight shape: [53248, 16384]
+24: model.layers.10.post_attention_layernorm.weight shape: [16384]
+25: model.layers.10.self_attn.k_proj.weight shape: [1024, 16384]
+26: model.layers.10.self_attn.o_proj.weight shape: [16384, 16384]
+27: model.layers.10.self_attn.q_proj.weight shape: [16384, 16384]
+28: model.layers.10.self_attn.v_proj.weight shape: [1024, 16384]
+29: model.layers.100.input_layernorm.weight shape: [16384]
+30: model.layers.100.mlp.down_proj.weight shape: [16384, 53248]
+31: model.layers.100.mlp.gate_proj.weight shape: [53248, 16384]
+32: model.layers.100.mlp.up_proj.weight shape: [53248, 16384]
+33: model.layers.100.post_attention_layernorm.weight shape: [16384]
+34: model.layers.100.self_attn.k_proj.weight shape: [1024, 16384]
+35: model.layers.100.self_attn.o_proj.weight shape: [16384, 16384]
+36: model.layers.100.self_attn.q_proj.weight shape: [16384, 16384]
+37: model.layers.100.self_attn.v_proj.weight shape: [1024, 16384]
+38: model.layers.101.input_layernorm.weight shape: [16384]
+39: model.layers.101.mlp.down_proj.weight shape: [16384, 53248]
+40: model.layers.101.mlp.gate_proj.weight shape: [53248, 16384]
+41: model.layers.101.mlp.up_proj.weight shape: [53248, 16384]
+42: model.layers.101.post_attention_layernorm.weight shape: [16384]
+43: model.layers.101.self_attn.k_proj.weight shape: [1024, 16384]
+44: model.layers.101.self_attn.o_proj.weight shape: [16384, 16384]
+45: model.layers.101.self_attn.q_proj.weight shape: [16384, 16384]
+46: model.layers.101.self_attn.v_proj.weight shape: [1024, 16384]
+47: model.layers.102.input_layernorm.weight shape: [16384]
+48: model.layers.102.mlp.down_proj.weight shape: [16384, 53248]
+49: model.layers.102.mlp.gate_proj.weight shape: [53248, 16384]
+50: model.layers.102.mlp.up_proj.weight shape: [53248, 16384]
+51: model.layers.102.post_attention_layernorm.weight shape: [16384]
+52: model.layers.102.self_attn.k_proj.weight shape: [1024, 16384]
+53: model.layers.102.self_attn.o_proj.weight shape: [16384, 16384]
+54: model.layers.102.self_attn.q_proj.weight shape: [16384, 16384]
+55: model.layers.102.self_attn.v_proj.weight shape: [1024, 16384]
+56: model.layers.103.input_layernorm.weight shape: [16384]
+57: model.layers.103.mlp.down_proj.weight shape: [16384, 53248]
+58: model.layers.103.mlp.gate_proj.weight shape: [53248, 16384]
+59: model.layers.103.mlp.up_proj.weight shape: [53248, 16384]
+60: model.layers.103.post_attention_layernorm.weight shape: [16384]
+61: model.layers.103.self_attn.k_proj.weight shape: [1024, 16384]
+62: model.layers.103.self_attn.o_proj.weight shape: [16384, 16384]
+63: model.layers.103.self_attn.q_proj.weight shape: [16384, 16384]
+64: model.layers.103.self_attn.v_proj.weight shape: [1024, 16384]
+65: model.layers.104.input_layernorm.weight shape: [16384]
+66: model.layers.104.mlp.down_proj.weight shape: [16384, 53248]
+67: model.layers.104.mlp.gate_proj.weight shape: [53248, 16384]
+68: model.layers.104.mlp.up_proj.weight shape: [53248, 16384]
+69: model.layers.104.post_attention_layernorm.weight shape: [16384]
+70: model.layers.104.self_attn.k_proj.weight shape: [1024, 16384]
+71: model.layers.104.self_attn.o_proj.weight shape: [16384, 16384]
+72: model.layers.104.self_attn.q_proj.weight shape: [16384, 16384]
+73: model.layers.104.self_attn.v_proj.weight shape: [1024, 16384]
+74: model.layers.105.input_layernorm.weight shape: [16384]
+75: model.layers.105.mlp.down_proj.weight shape: [16384, 53248]
+76: model.layers.105.mlp.gate_proj.weight shape: [53248, 16384]
+77: model.layers.105.mlp.up_proj.weight shape: [53248, 16384]
+78: model.layers.105.post_attention_layernorm.weight shape: [16384]
+79: model.layers.105.self_attn.k_proj.weight shape: [1024, 16384]
+80: model.layers.105.self_attn.o_proj.weight shape: [16384, 16384]
+81: model.layers.105.self_attn.q_proj.weight shape: [16384, 16384]
+82: model.layers.105.self_attn.v_proj.weight shape: [1024, 16384]
+83: model.layers.106.input_layernorm.weight shape: [16384]
+84: model.layers.106.mlp.down_proj.weight shape: [16384, 53248]
+85: model.layers.106.mlp.gate_proj.weight shape: [53248, 16384]
+86: model.layers.106.mlp.up_proj.weight shape: [53248, 16384]
+87: model.layers.106.post_attention_layernorm.weight shape: [16384]
+88: model.layers.106.self_attn.k_proj.weight shape: [1024, 16384]
+89: model.layers.106.self_attn.o_proj.weight shape: [16384, 16384]
+90: model.layers.106.self_attn.q_proj.weight shape: [16384, 16384]
+91: model.layers.106.self_attn.v_proj.weight shape: [1024, 16384]
+92: model.layers.107.input_layernorm.weight shape: [16384]
+93: model.layers.107.mlp.down_proj.weight shape: [16384, 53248]
+94: model.layers.107.mlp.gate_proj.weight shape: [53248, 16384]
+95: model.layers.107.mlp.up_proj.weight shape: [53248, 16384]
+96: model.layers.107.post_attention_layernorm.weight shape: [16384]
+97: model.layers.107.self_attn.k_proj.weight shape: [1024, 16384]
+98: model.layers.107.self_attn.o_proj.weight shape: [16384, 16384]
+99: model.layers.107.self_attn.q_proj.weight shape: [16384, 16384]
+100: model.layers.107.self_attn.v_proj.weight shape: [1024, 16384]
+101: model.layers.108.input_layernorm.weight shape: [16384]
+102: model.layers.108.mlp.down_proj.weight shape: [16384, 53248]
+103: model.layers.108.mlp.gate_proj.weight shape: [53248, 16384]
+104: model.layers.108.mlp.up_proj.weight shape: [53248, 16384]
+105: model.layers.108.post_attention_layernorm.weight shape: [16384]
+106: model.layers.108.self_attn.k_proj.weight shape: [1024, 16384]
+107: model.layers.108.self_attn.o_proj.weight shape: [16384, 16384]
+108: model.layers.108.self_attn.q_proj.weight shape: [16384, 16384]
+109: model.layers.108.self_attn.v_proj.weight shape: [1024, 16384]
+110: model.layers.109.input_layernorm.weight shape: [16384]
+111: model.layers.109.mlp.down_proj.weight shape: [16384, 53248]
+112: model.layers.109.mlp.gate_proj.weight shape: [53248, 16384]
+113: model.layers.109.mlp.up_proj.weight shape: [53248, 16384]
+114: model.layers.109.post_attention_layernorm.weight shape: [16384]
+115: model.layers.109.self_attn.k_proj.weight shape: [1024, 16384]
+116: model.layers.109.self_attn.o_proj.weight shape: [16384, 16384]
+117: model.layers.109.self_attn.q_proj.weight shape: [16384, 16384]
+118: model.layers.109.self_attn.v_proj.weight shape: [1024, 16384]
+119: model.layers.11.input_layernorm.weight shape: [16384]
+120: model.layers.11.mlp.down_proj.weight shape: [16384, 53248]
+121: model.layers.11.mlp.gate_proj.weight shape: [53248, 16384]
+122: model.layers.11.mlp.up_proj.weight shape: [53248, 16384]
+123: model.layers.11.post_attention_layernorm.weight shape: [16384]
+124: model.layers.11.self_attn.k_proj.weight shape: [1024, 16384]
+125: model.layers.11.self_attn.o_proj.weight shape: [16384, 16384]
+126: model.layers.11.self_attn.q_proj.weight shape: [16384, 16384]
+127: model.layers.11.self_attn.v_proj.weight shape: [1024, 16384]
+128: model.layers.110.input_layernorm.weight shape: [16384]
+129: model.layers.110.mlp.down_proj.weight shape: [16384, 53248]
+130: model.layers.110.mlp.gate_proj.weight shape: [53248, 16384]
+131: model.layers.110.mlp.up_proj.weight shape: [53248, 16384]
+132: model.layers.110.post_attention_layernorm.weight shape: [16384]
+133: model.layers.110.self_attn.k_proj.weight shape: [1024, 16384]
+134: model.layers.110.self_attn.o_proj.weight shape: [16384, 16384]
+135: model.layers.110.self_attn.q_proj.weight shape: [16384, 16384]
+136: model.layers.110.self_attn.v_proj.weight shape: [1024, 16384]
+137: model.layers.111.input_layernorm.weight shape: [16384]
+138: model.layers.111.mlp.down_proj.weight shape: [16384, 53248]
+139: model.layers.111.mlp.gate_proj.weight shape: [53248, 16384]
+140: model.layers.111.mlp.up_proj.weight shape: [53248, 16384]
+141: model.layers.111.post_attention_layernorm.weight shape: [16384]
+142: model.layers.111.self_attn.k_proj.weight shape: [1024, 16384]
+143: model.layers.111.self_attn.o_proj.weight shape: [16384, 16384]
+144: model.layers.111.self_attn.q_proj.weight shape: [16384, 16384]
+145: model.layers.111.self_attn.v_proj.weight shape: [1024, 16384]
+146: model.layers.112.input_layernorm.weight shape: [16384]
+147: model.layers.112.mlp.down_proj.weight shape: [16384, 53248]
+148: model.layers.112.mlp.gate_proj.weight shape: [53248, 16384]
+149: model.layers.112.mlp.up_proj.weight shape: [53248, 16384]
+150: model.layers.112.post_attention_layernorm.weight shape: [16384]
+151: model.layers.112.self_attn.k_proj.weight shape: [1024, 16384]
+152: model.layers.112.self_attn.o_proj.weight shape: [16384, 16384]
+153: model.layers.112.self_attn.q_proj.weight shape: [16384, 16384]
+154: model.layers.112.self_attn.v_proj.weight shape: [1024, 16384]
+155: model.layers.113.input_layernorm.weight shape: [16384]
+156: model.layers.113.mlp.down_proj.weight shape: [16384, 53248]
+157: model.layers.113.mlp.gate_proj.weight shape: [53248, 16384]
+158: model.layers.113.mlp.up_proj.weight shape: [53248, 16384]
+159: model.layers.113.post_attention_layernorm.weight shape: [16384]
+160: model.layers.113.self_attn.k_proj.weight shape: [1024, 16384]
+161: model.layers.113.self_attn.o_proj.weight shape: [16384, 16384]
+162: model.layers.113.self_attn.q_proj.weight shape: [16384, 16384]
+163: model.layers.113.self_attn.v_proj.weight shape: [1024, 16384]
+164: model.layers.114.input_layernorm.weight shape: [16384]
+165: model.layers.114.mlp.down_proj.weight shape: [16384, 53248]
+166: model.layers.114.mlp.gate_proj.weight shape: [53248, 16384]
+167: model.layers.114.mlp.up_proj.weight shape: [53248, 16384]
+168: model.layers.114.post_attention_layernorm.weight shape: [16384]
+169: model.layers.114.self_attn.k_proj.weight shape: [1024, 16384]
+170: model.layers.114.self_attn.o_proj.weight shape: [16384, 16384]
+171: model.layers.114.self_attn.q_proj.weight shape: [16384, 16384]
+172: model.layers.114.self_attn.v_proj.weight shape: [1024, 16384]
+173: model.layers.115.input_layernorm.weight shape: [16384]
+174: model.layers.115.mlp.down_proj.weight shape: [16384, 53248]
+175: model.layers.115.mlp.gate_proj.weight shape: [53248, 16384]
+176: model.layers.115.mlp.up_proj.weight shape: [53248, 16384]
+177: model.layers.115.post_attention_layernorm.weight shape: [16384]
+178: model.layers.115.self_attn.k_proj.weight shape: [1024, 16384]
+179: model.layers.115.self_attn.o_proj.weight shape: [16384, 16384]
+180: model.layers.115.self_attn.q_proj.weight shape: [16384, 16384]
+181: model.layers.115.self_attn.v_proj.weight shape: [1024, 16384]
+182: model.layers.116.input_layernorm.weight shape: [16384]
+183: model.layers.116.mlp.down_proj.weight shape: [16384, 53248]
+184: model.layers.116.mlp.gate_proj.weight shape: [53248, 16384]
+185: model.layers.116.mlp.up_proj.weight shape: [53248, 16384]
+186: model.layers.116.post_attention_layernorm.weight shape: [16384]
+187: model.layers.116.self_attn.k_proj.weight shape: [1024, 16384]
+188: model.layers.116.self_attn.o_proj.weight shape: [16384, 16384]
+189: model.layers.116.self_attn.q_proj.weight shape: [16384, 16384]
+190: model.layers.116.self_attn.v_proj.weight shape: [1024, 16384]
+191: model.layers.117.input_layernorm.weight shape: [16384]
+192: model.layers.117.mlp.down_proj.weight shape: [16384, 53248]
+193: model.layers.117.mlp.gate_proj.weight shape: [53248, 16384]
+194: model.layers.117.mlp.up_proj.weight shape: [53248, 16384]
+195: model.layers.117.post_attention_layernorm.weight shape: [16384]
+196: model.layers.117.self_attn.k_proj.weight shape: [1024, 16384]
+197: model.layers.117.self_attn.o_proj.weight shape: [16384, 16384]
+198: model.layers.117.self_attn.q_proj.weight shape: [16384, 16384]
+199: model.layers.117.self_attn.v_proj.weight shape: [1024, 16384]
+200: model.layers.118.input_layernorm.weight shape: [16384]
+201: model.layers.118.mlp.down_proj.weight shape: [16384, 53248]
+202: model.layers.118.mlp.gate_proj.weight shape: [53248, 16384]
+203: model.layers.118.mlp.up_proj.weight shape: [53248, 16384]
+204: model.layers.118.post_attention_layernorm.weight shape: [16384]
+205: model.layers.118.self_attn.k_proj.weight shape: [1024, 16384]
+206: model.layers.118.self_attn.o_proj.weight shape: [16384, 16384]
+207: model.layers.118.self_attn.q_proj.weight shape: [16384, 16384]
+208: model.layers.118.self_attn.v_proj.weight shape: [1024, 16384]
+209: model.layers.119.input_layernorm.weight shape: [16384]
+210: model.layers.119.mlp.down_proj.weight shape: [16384, 53248]
+211: model.layers.119.mlp.gate_proj.weight shape: [53248, 16384]
+212: model.layers.119.mlp.up_proj.weight shape: [53248, 16384]
+213: model.layers.119.post_attention_layernorm.weight shape: [16384]
+214: model.layers.119.self_attn.k_proj.weight shape: [1024, 16384]
+215: model.layers.119.self_attn.o_proj.weight shape: [16384, 16384]
+216: model.layers.119.self_attn.q_proj.weight shape: [16384, 16384]
+217: model.layers.119.self_attn.v_proj.weight shape: [1024, 16384]
+218: model.layers.12.input_layernorm.weight shape: [16384]
+219: model.layers.12.mlp.down_proj.weight shape: [16384, 53248]
+220: model.layers.12.mlp.gate_proj.weight shape: [53248, 16384]
+221: model.layers.12.mlp.up_proj.weight shape: [53248, 16384]
+222: model.layers.12.post_attention_layernorm.weight shape: [16384]
+223: model.layers.12.self_attn.k_proj.weight shape: [1024, 16384]
+224: model.layers.12.self_attn.o_proj.weight shape: [16384, 16384]
+225: model.layers.12.self_attn.q_proj.weight shape: [16384, 16384]
+226: model.layers.12.self_attn.v_proj.weight shape: [1024, 16384]
+227: model.layers.120.input_layernorm.weight shape: [16384]
+228: model.layers.120.mlp.down_proj.weight shape: [16384, 53248]
+229: model.layers.120.mlp.gate_proj.weight shape: [53248, 16384]
+230: model.layers.120.mlp.up_proj.weight shape: [53248, 16384]
+231: model.layers.120.post_attention_layernorm.weight shape: [16384]
+232: model.layers.120.self_attn.k_proj.weight shape: [1024, 16384]
+233: model.layers.120.self_attn.o_proj.weight shape: [16384, 16384]
+234: model.layers.120.self_attn.q_proj.weight shape: [16384, 16384]
+235: model.layers.120.self_attn.v_proj.weight shape: [1024, 16384]
+236: model.layers.121.input_layernorm.weight shape: [16384]
+237: model.layers.121.mlp.down_proj.weight shape: [16384, 53248]
+238: model.layers.121.mlp.gate_proj.weight shape: [53248, 16384]
+239: model.layers.121.mlp.up_proj.weight shape: [53248, 16384]
+240: model.layers.121.post_attention_layernorm.weight shape: [16384]
+241: model.layers.121.self_attn.k_proj.weight shape: [1024, 16384]
+242: model.layers.121.self_attn.o_proj.weight shape: [16384, 16384]
+243: model.layers.121.self_attn.q_proj.weight shape: [16384, 16384]
+244: model.layers.121.self_attn.v_proj.weight shape: [1024, 16384]
+245: model.layers.122.input_layernorm.weight shape: [16384]
+246: model.layers.122.mlp.down_proj.weight shape: [16384, 53248]
+247: model.layers.122.mlp.gate_proj.weight shape: [53248, 16384]
+248: model.layers.122.mlp.up_proj.weight shape: [53248, 16384]
+249: model.layers.122.post_attention_layernorm.weight shape: [16384]
+250: model.layers.122.self_attn.k_proj.weight shape: [1024, 16384]
+251: model.layers.122.self_attn.o_proj.weight shape: [16384, 16384]
+252: model.layers.122.self_attn.q_proj.weight shape: [16384, 16384]
+253: model.layers.122.self_attn.v_proj.weight shape: [1024, 16384]
+254: model.layers.123.input_layernorm.weight shape: [16384]
+255: model.layers.123.mlp.down_proj.weight shape: [16384, 53248]
+256: model.layers.123.mlp.gate_proj.weight shape: [53248, 16384]
+257: model.layers.123.mlp.up_proj.weight shape: [53248, 16384]
+258: model.layers.123.post_attention_layernorm.weight shape: [16384]
+259: model.layers.123.self_attn.k_proj.weight shape: [1024, 16384]
+260: model.layers.123.self_attn.o_proj.weight shape: [16384, 16384]
+261: model.layers.123.self_attn.q_proj.weight shape: [16384, 16384]
+262: model.layers.123.self_attn.v_proj.weight shape: [1024, 16384]
+263: model.layers.124.input_layernorm.weight shape: [16384]
+264: model.layers.124.mlp.down_proj.weight shape: [16384, 53248]
+265: model.layers.124.mlp.gate_proj.weight shape: [53248, 16384]
+266: model.layers.124.mlp.up_proj.weight shape: [53248, 16384]
+267: model.layers.124.post_attention_layernorm.weight shape: [16384]
+268: model.layers.124.self_attn.k_proj.weight shape: [1024, 16384]
+269: model.layers.124.self_attn.o_proj.weight shape: [16384, 16384]
+270: model.layers.124.self_attn.q_proj.weight shape: [16384, 16384]
+271: model.layers.124.self_attn.v_proj.weight shape: [1024, 16384]
+272: model.layers.125.input_layernorm.weight shape: [16384]
+273: model.layers.125.mlp.down_proj.weight shape: [16384, 53248]
+274: model.layers.125.mlp.gate_proj.weight shape: [53248, 16384]
+275: model.layers.125.mlp.up_proj.weight shape: [53248, 16384]
+276: model.layers.125.post_attention_layernorm.weight shape: [16384]
+277: model.layers.125.self_attn.k_proj.weight shape: [1024, 16384]
+278: model.layers.125.self_attn.o_proj.weight shape: [16384, 16384]
+279: model.layers.125.self_attn.q_proj.weight shape: [16384, 16384]
+280: model.layers.125.self_attn.v_proj.weight shape: [1024, 16384]
+281: model.layers.13.input_layernorm.weight shape: [16384]
+282: model.layers.13.mlp.down_proj.weight shape: [16384, 53248]
+283: model.layers.13.mlp.gate_proj.weight shape: [53248, 16384]
+284: model.layers.13.mlp.up_proj.weight shape: [53248, 16384]
+285: model.layers.13.post_attention_layernorm.weight shape: [16384]
+286: model.layers.13.self_attn.k_proj.weight shape: [1024, 16384]
+287: model.layers.13.self_attn.o_proj.weight shape: [16384, 16384]
+288: model.layers.13.self_attn.q_proj.weight shape: [16384, 16384]
+289: model.layers.13.self_attn.v_proj.weight shape: [1024, 16384]
+290: model.layers.14.input_layernorm.weight shape: [16384]
+291: model.layers.14.mlp.down_proj.weight shape: [16384, 53248]
+292: model.layers.14.mlp.gate_proj.weight shape: [53248, 16384]
+293: model.layers.14.mlp.up_proj.weight shape: [53248, 16384]
+294: model.layers.14.post_attention_layernorm.weight shape: [16384]
+295: model.layers.14.self_attn.k_proj.weight shape: [1024, 16384]
+296: model.layers.14.self_attn.o_proj.weight shape: [16384, 16384]
+297: model.layers.14.self_attn.q_proj.weight shape: [16384, 16384]
+298: model.layers.14.self_attn.v_proj.weight shape: [1024, 16384]
+299: model.layers.15.input_layernorm.weight shape: [16384]
+300: model.layers.15.mlp.down_proj.weight shape: [16384, 53248]
+301: model.layers.15.mlp.gate_proj.weight shape: [53248, 16384]
+302: model.layers.15.mlp.up_proj.weight shape: [53248, 16384]
+303: model.layers.15.post_attention_layernorm.weight shape: [16384]
+304: model.layers.15.self_attn.k_proj.weight shape: [1024, 16384]
+305: model.layers.15.self_attn.o_proj.weight shape: [16384, 16384]
+306: model.layers.15.self_attn.q_proj.weight shape: [16384, 16384]
+307: model.layers.15.self_attn.v_proj.weight shape: [1024, 16384]
+308: model.layers.16.input_layernorm.weight shape: [16384]
+309: model.layers.16.mlp.down_proj.weight shape: [16384, 53248]
+310: model.layers.16.mlp.gate_proj.weight shape: [53248, 16384]
+311: model.layers.16.mlp.up_proj.weight shape: [53248, 16384]
+312: model.layers.16.post_attention_layernorm.weight shape: [16384]
+313: model.layers.16.self_attn.k_proj.weight shape: [1024, 16384]
+314: model.layers.16.self_attn.o_proj.weight shape: [16384, 16384]
+315: model.layers.16.self_attn.q_proj.weight shape: [16384, 16384]
+316: model.layers.16.self_attn.v_proj.weight shape: [1024, 16384]
+317: model.layers.17.input_layernorm.weight shape: [16384]
+318: model.layers.17.mlp.down_proj.weight shape: [16384, 53248]
+319: model.layers.17.mlp.gate_proj.weight shape: [53248, 16384]
+320: model.layers.17.mlp.up_proj.weight shape: [53248, 16384]
+321: model.layers.17.post_attention_layernorm.weight shape: [16384]
+322: model.layers.17.self_attn.k_proj.weight shape: [1024, 16384]
+323: model.layers.17.self_attn.o_proj.weight shape: [16384, 16384]
+324: model.layers.17.self_attn.q_proj.weight shape: [16384, 16384]
+325: model.layers.17.self_attn.v_proj.weight shape: [1024, 16384]
+326: model.layers.18.input_layernorm.weight shape: [16384]
+327: model.layers.18.mlp.down_proj.weight shape: [16384, 53248]
+328: model.layers.18.mlp.gate_proj.weight shape: [53248, 16384]
+329: model.layers.18.mlp.up_proj.weight shape: [53248, 16384]
+330: model.layers.18.post_attention_layernorm.weight shape: [16384]
+331: model.layers.18.self_attn.k_proj.weight shape: [1024, 16384]
+332: model.layers.18.self_attn.o_proj.weight shape: [16384, 16384]
+333: model.layers.18.self_attn.q_proj.weight shape: [16384, 16384]
+334: model.layers.18.self_attn.v_proj.weight shape: [1024, 16384]
+335: model.layers.19.input_layernorm.weight shape: [16384]
+336: model.layers.19.mlp.down_proj.weight shape: [16384, 53248]
+337: model.layers.19.mlp.gate_proj.weight shape: [53248, 16384]
+338: model.layers.19.mlp.up_proj.weight shape: [53248, 16384]
+339: model.layers.19.post_attention_layernorm.weight shape: [16384]
+340: model.layers.19.self_attn.k_proj.weight shape: [1024, 16384]
+341: model.layers.19.self_attn.o_proj.weight shape: [16384, 16384]
+342: model.layers.19.self_attn.q_proj.weight shape: [16384, 16384]
+343: model.layers.19.self_attn.v_proj.weight shape: [1024, 16384]
+344: model.layers.2.input_layernorm.weight shape: [16384]
+345: model.layers.2.mlp.down_proj.weight shape: [16384, 53248]
+346: model.layers.2.mlp.gate_proj.weight shape: [53248, 16384]
+347: model.layers.2.mlp.up_proj.weight shape: [53248, 16384]
+348: model.layers.2.post_attention_layernorm.weight shape: [16384]
+349: model.layers.2.self_attn.k_proj.weight shape: [1024, 16384]
+350: model.layers.2.self_attn.o_proj.weight shape: [16384, 16384]
+351: model.layers.2.self_attn.q_proj.weight shape: [16384, 16384]
+352: model.layers.2.self_attn.v_proj.weight shape: [1024, 16384]
+353: model.layers.20.input_layernorm.weight shape: [16384]
+354: model.layers.20.mlp.down_proj.weight shape: [16384, 53248]
+355: model.layers.20.mlp.gate_proj.weight shape: [53248, 16384]
+356: model.layers.20.mlp.up_proj.weight shape: [53248, 16384]
+357: model.layers.20.post_attention_layernorm.weight shape: [16384]
+358: model.layers.20.self_attn.k_proj.weight shape: [1024, 16384]
+359: model.layers.20.self_attn.o_proj.weight shape: [16384, 16384]
+360: model.layers.20.self_attn.q_proj.weight shape: [16384, 16384]
+361: model.layers.20.self_attn.v_proj.weight shape: [1024, 16384]
+362: model.layers.21.input_layernorm.weight shape: [16384]
+363: model.layers.21.mlp.down_proj.weight shape: [16384, 53248]
+364: model.layers.21.mlp.gate_proj.weight shape: [53248, 16384]
+365: model.layers.21.mlp.up_proj.weight shape: [53248, 16384]
+366: model.layers.21.post_attention_layernorm.weight shape: [16384]
+367: model.layers.21.self_attn.k_proj.weight shape: [1024, 16384]
+368: model.layers.21.self_attn.o_proj.weight shape: [16384, 16384]
+369: model.layers.21.self_attn.q_proj.weight shape: [16384, 16384]
+370: model.layers.21.self_attn.v_proj.weight shape: [1024, 16384]
+371: model.layers.22.input_layernorm.weight shape: [16384]
+372: model.layers.22.mlp.down_proj.weight shape: [16384, 53248]
+373: model.layers.22.mlp.gate_proj.weight shape: [53248, 16384]
+374: model.layers.22.mlp.up_proj.weight shape: [53248, 16384]
+375: model.layers.22.post_attention_layernorm.weight shape: [16384]
+376: model.layers.22.self_attn.k_proj.weight shape: [1024, 16384]
+377: model.layers.22.self_attn.o_proj.weight shape: [16384, 16384]
+378: model.layers.22.self_attn.q_proj.weight shape: [16384, 16384]
+379: model.layers.22.self_attn.v_proj.weight shape: [1024, 16384]
+380: model.layers.23.input_layernorm.weight shape: [16384]
+381: model.layers.23.mlp.down_proj.weight shape: [16384, 53248]
+382: model.layers.23.mlp.gate_proj.weight shape: [53248, 16384]
+383: model.layers.23.mlp.up_proj.weight shape: [53248, 16384]
+384: model.layers.23.post_attention_layernorm.weight shape: [16384]
+385: model.layers.23.self_attn.k_proj.weight shape: [1024, 16384]
+386: model.layers.23.self_attn.o_proj.weight shape: [16384, 16384]
+387: model.layers.23.self_attn.q_proj.weight shape: [16384, 16384]
+388: model.layers.23.self_attn.v_proj.weight shape: [1024, 16384]
+389: model.layers.24.input_layernorm.weight shape: [16384]
+390: model.layers.24.mlp.down_proj.weight shape: [16384, 53248]
+391: model.layers.24.mlp.gate_proj.weight shape: [53248, 16384]
+392: model.layers.24.mlp.up_proj.weight shape: [53248, 16384]
+393: model.layers.24.post_attention_layernorm.weight shape: [16384]
+394: model.layers.24.self_attn.k_proj.weight shape: [1024, 16384]
+395: model.layers.24.self_attn.o_proj.weight shape: [16384, 16384]
+396: model.layers.24.self_attn.q_proj.weight shape: [16384, 16384]
+397: model.layers.24.self_attn.v_proj.weight shape: [1024, 16384]
+398: model.layers.25.input_layernorm.weight shape: [16384]
+399: model.layers.25.mlp.down_proj.weight shape: [16384, 53248]
+400: model.layers.25.mlp.gate_proj.weight shape: [53248, 16384]
+401: model.layers.25.mlp.up_proj.weight shape: [53248, 16384]
+402: model.layers.25.post_attention_layernorm.weight shape: [16384]
+403: model.layers.25.self_attn.k_proj.weight shape: [1024, 16384]
+404: model.layers.25.self_attn.o_proj.weight shape: [16384, 16384]
+405: model.layers.25.self_attn.q_proj.weight shape: [16384, 16384]
+406: model.layers.25.self_attn.v_proj.weight shape: [1024, 16384]
+407: model.layers.26.input_layernorm.weight shape: [16384]
+408: model.layers.26.mlp.down_proj.weight shape: [16384, 53248]
+409: model.layers.26.mlp.gate_proj.weight shape: [53248, 16384]
+410: model.layers.26.mlp.up_proj.weight shape: [53248, 16384]
+411: model.layers.26.post_attention_layernorm.weight shape: [16384]
+412: model.layers.26.self_attn.k_proj.weight shape: [1024, 16384]
+413: model.layers.26.self_attn.o_proj.weight shape: [16384, 16384]
+414: model.layers.26.self_attn.q_proj.weight shape: [16384, 16384]
+415: model.layers.26.self_attn.v_proj.weight shape: [1024, 16384]
+416: model.layers.27.input_layernorm.weight shape: [16384]
+417: model.layers.27.mlp.down_proj.weight shape: [16384, 53248]
+418: model.layers.27.mlp.gate_proj.weight shape: [53248, 16384]
+419: model.layers.27.mlp.up_proj.weight shape: [53248, 16384]
+420: model.layers.27.post_attention_layernorm.weight shape: [16384]
+421: model.layers.27.self_attn.k_proj.weight shape: [1024, 16384]
+422: model.layers.27.self_attn.o_proj.weight shape: [16384, 16384]
+423: model.layers.27.self_attn.q_proj.weight shape: [16384, 16384]
+424: model.layers.27.self_attn.v_proj.weight shape: [1024, 16384]
+425: model.layers.28.input_layernorm.weight shape: [16384]
+426: model.layers.28.mlp.down_proj.weight shape: [16384, 53248]
+427: model.layers.28.mlp.gate_proj.weight shape: [53248, 16384]
+428: model.layers.28.mlp.up_proj.weight shape: [53248, 16384]
+429: model.layers.28.post_attention_layernorm.weight shape: [16384]
+430: model.layers.28.self_attn.k_proj.weight shape: [1024, 16384]
+431: model.layers.28.self_attn.o_proj.weight shape: [16384, 16384]
+432: model.layers.28.self_attn.q_proj.weight shape: [16384, 16384]
+433: model.layers.28.self_attn.v_proj.weight shape: [1024, 16384]
+434: model.layers.29.input_layernorm.weight shape: [16384]
+435: model.layers.29.mlp.down_proj.weight shape: [16384, 53248]
+436: model.layers.29.mlp.gate_proj.weight shape: [53248, 16384]
+437: model.layers.29.mlp.up_proj.weight shape: [53248, 16384]
+438: model.layers.29.post_attention_layernorm.weight shape: [16384]
+439: model.layers.29.self_attn.k_proj.weight shape: [1024, 16384]
+440: model.layers.29.self_attn.o_proj.weight shape: [16384, 16384]
+441: model.layers.29.self_attn.q_proj.weight shape: [16384, 16384]
+442: model.layers.29.self_attn.v_proj.weight shape: [1024, 16384]
+443: model.layers.3.input_layernorm.weight shape: [16384]
+444: model.layers.3.mlp.down_proj.weight shape: [16384, 53248]
+445: model.layers.3.mlp.gate_proj.weight shape: [53248, 16384]
+446: model.layers.3.mlp.up_proj.weight shape: [53248, 16384]
+447: model.layers.3.post_attention_layernorm.weight shape: [16384]
+448: model.layers.3.self_attn.k_proj.weight shape: [1024, 16384]
+449: model.layers.3.self_attn.o_proj.weight shape: [16384, 16384]
+450: model.layers.3.self_attn.q_proj.weight shape: [16384, 16384]
+451: model.layers.3.self_attn.v_proj.weight shape: [1024, 16384]
+452: model.layers.30.input_layernorm.weight shape: [16384]
+453: model.layers.30.mlp.down_proj.weight shape: [16384, 53248]
+454: model.layers.30.mlp.gate_proj.weight shape: [53248, 16384]
+455: model.layers.30.mlp.up_proj.weight shape: [53248, 16384]
+456: model.layers.30.post_attention_layernorm.weight shape: [16384]
+457: model.layers.30.self_attn.k_proj.weight shape: [1024, 16384]
+458: model.layers.30.self_attn.o_proj.weight shape: [16384, 16384]
+459: model.layers.30.self_attn.q_proj.weight shape: [16384, 16384]
+460: model.layers.30.self_attn.v_proj.weight shape: [1024, 16384]
+461: model.layers.31.input_layernorm.weight shape: [16384]
+462: model.layers.31.mlp.down_proj.weight shape: [16384, 53248]
+463: model.layers.31.mlp.gate_proj.weight shape: [53248, 16384]
+464: model.layers.31.mlp.up_proj.weight shape: [53248, 16384]
+465: model.layers.31.post_attention_layernorm.weight shape: [16384]
+466: model.layers.31.self_attn.k_proj.weight shape: [1024, 16384]
+467: model.layers.31.self_attn.o_proj.weight shape: [16384, 16384]
+468: model.layers.31.self_attn.q_proj.weight shape: [16384, 16384]
+469: model.layers.31.self_attn.v_proj.weight shape: [1024, 16384]
+470: model.layers.32.input_layernorm.weight shape: [16384]
+471: model.layers.32.mlp.down_proj.weight shape: [16384, 53248]
+472: model.layers.32.mlp.gate_proj.weight shape: [53248, 16384]
+473: model.layers.32.mlp.up_proj.weight shape: [53248, 16384]
+474: model.layers.32.post_attention_layernorm.weight shape: [16384]
+475: model.layers.32.self_attn.k_proj.weight shape: [1024, 16384]
+476: model.layers.32.self_attn.o_proj.weight shape: [16384, 16384]
+477: model.layers.32.self_attn.q_proj.weight shape: [16384, 16384]
+478: model.layers.32.self_attn.v_proj.weight shape: [1024, 16384]
+479: model.layers.33.input_layernorm.weight shape: [16384]
+480: model.layers.33.mlp.down_proj.weight shape: [16384, 53248]
+481: model.layers.33.mlp.gate_proj.weight shape: [53248, 16384]
+482: model.layers.33.mlp.up_proj.weight shape: [53248, 16384]
+483: model.layers.33.post_attention_layernorm.weight shape: [16384]
+484: model.layers.33.self_attn.k_proj.weight shape: [1024, 16384]
+485: model.layers.33.self_attn.o_proj.weight shape: [16384, 16384]
+486: model.layers.33.self_attn.q_proj.weight shape: [16384, 16384]
+487: model.layers.33.self_attn.v_proj.weight shape: [1024, 16384]
+488: model.layers.34.input_layernorm.weight shape: [16384]
+489: model.layers.34.mlp.down_proj.weight shape: [16384, 53248]
+490: model.layers.34.mlp.gate_proj.weight shape: [53248, 16384]
+491: model.layers.34.mlp.up_proj.weight shape: [53248, 16384]
+492: model.layers.34.post_attention_layernorm.weight shape: [16384]
+493: model.layers.34.self_attn.k_proj.weight shape: [1024, 16384]
+494: model.layers.34.self_attn.o_proj.weight shape: [16384, 16384]
+495: model.layers.34.self_attn.q_proj.weight shape: [16384, 16384]
+496: model.layers.34.self_attn.v_proj.weight shape: [1024, 16384]
+497: model.layers.35.input_layernorm.weight shape: [16384]
+498: model.layers.35.mlp.down_proj.weight shape: [16384, 53248]
+499: model.layers.35.mlp.gate_proj.weight shape: [53248, 16384]
+500: model.layers.35.mlp.up_proj.weight shape: [53248, 16384]
+501: model.layers.35.post_attention_layernorm.weight shape: [16384]
+502: model.layers.35.self_attn.k_proj.weight shape: [1024, 16384]
+503: model.layers.35.self_attn.o_proj.weight shape: [16384, 16384]
+504: model.layers.35.self_attn.q_proj.weight shape: [16384, 16384]
+505: model.layers.35.self_attn.v_proj.weight shape: [1024, 16384]
+506: model.layers.36.input_layernorm.weight shape: [16384]
+507: model.layers.36.mlp.down_proj.weight shape: [16384, 53248]
+508: model.layers.36.mlp.gate_proj.weight shape: [53248, 16384]
+509: model.layers.36.mlp.up_proj.weight shape: [53248, 16384]
+510: model.layers.36.post_attention_layernorm.weight shape: [16384]
+511: model.layers.36.self_attn.k_proj.weight shape: [1024, 16384]
+512: model.layers.36.self_attn.o_proj.weight shape: [16384, 16384]
+513: model.layers.36.self_attn.q_proj.weight shape: [16384, 16384]
+514: model.layers.36.self_attn.v_proj.weight shape: [1024, 16384]
+515: model.layers.37.input_layernorm.weight shape: [16384]
+516: model.layers.37.mlp.down_proj.weight shape: [16384, 53248]
+517: model.layers.37.mlp.gate_proj.weight shape: [53248, 16384]
+518: model.layers.37.mlp.up_proj.weight shape: [53248, 16384]
+519: model.layers.37.post_attention_layernorm.weight shape: [16384]
+520: model.layers.37.self_attn.k_proj.weight shape: [1024, 16384]
+521: model.layers.37.self_attn.o_proj.weight shape: [16384, 16384]
+522: model.layers.37.self_attn.q_proj.weight shape: [16384, 16384]
+523: model.layers.37.self_attn.v_proj.weight shape: [1024, 16384]
+524: model.layers.38.input_layernorm.weight shape: [16384]
+525: model.layers.38.mlp.down_proj.weight shape: [16384, 53248]
+526: model.layers.38.mlp.gate_proj.weight shape: [53248, 16384]
+527: model.layers.38.mlp.up_proj.weight shape: [53248, 16384]
+528: model.layers.38.post_attention_layernorm.weight shape: [16384]
+529: model.layers.38.self_attn.k_proj.weight shape: [1024, 16384]
+530: model.layers.38.self_attn.o_proj.weight shape: [16384, 16384]
+531: model.layers.38.self_attn.q_proj.weight shape: [16384, 16384]
+532: model.layers.38.self_attn.v_proj.weight shape: [1024, 16384]
+533: model.layers.39.input_layernorm.weight shape: [16384]
+534: model.layers.39.mlp.down_proj.weight shape: [16384, 53248]
+535: model.layers.39.mlp.gate_proj.weight shape: [53248, 16384]
+536: model.layers.39.mlp.up_proj.weight shape: [53248, 16384]
+537: model.layers.39.post_attention_layernorm.weight shape: [16384]
+538: model.layers.39.self_attn.k_proj.weight shape: [1024, 16384]
+539: model.layers.39.self_attn.o_proj.weight shape: [16384, 16384]
+540: model.layers.39.self_attn.q_proj.weight shape: [16384, 16384]
+541: model.layers.39.self_attn.v_proj.weight shape: [1024, 16384]
+542: model.layers.4.input_layernorm.weight shape: [16384]
+543: model.layers.4.mlp.down_proj.weight shape: [16384, 53248]
+544: model.layers.4.mlp.gate_proj.weight shape: [53248, 16384]
+545: model.layers.4.mlp.up_proj.weight shape: [53248, 16384]
+546: model.layers.4.post_attention_layernorm.weight shape: [16384]
+547: model.layers.4.self_attn.k_proj.weight shape: [1024, 16384]
+548: model.layers.4.self_attn.o_proj.weight shape: [16384, 16384]
+549: model.layers.4.self_attn.q_proj.weight shape: [16384, 16384]
+550: model.layers.4.self_attn.v_proj.weight shape: [1024, 16384]
+551: model.layers.40.input_layernorm.weight shape: [16384]
+552: model.layers.40.mlp.down_proj.weight shape: [16384, 53248]
+553: model.layers.40.mlp.gate_proj.weight shape: [53248, 16384]
+554: model.layers.40.mlp.up_proj.weight shape: [53248, 16384]
+555: model.layers.40.post_attention_layernorm.weight shape: [16384]
+556: model.layers.40.self_attn.k_proj.weight shape: [1024, 16384]
+557: model.layers.40.self_attn.o_proj.weight shape: [16384, 16384]
+558: model.layers.40.self_attn.q_proj.weight shape: [16384, 16384]
+559: model.layers.40.self_attn.v_proj.weight shape: [1024, 16384]
+560: model.layers.41.input_layernorm.weight shape: [16384]
+561: model.layers.41.mlp.down_proj.weight shape: [16384, 53248]
+562: model.layers.41.mlp.gate_proj.weight shape: [53248, 16384]
+563: model.layers.41.mlp.up_proj.weight shape: [53248, 16384]
+564: model.layers.41.post_attention_layernorm.weight shape: [16384]
+565: model.layers.41.self_attn.k_proj.weight shape: [1024, 16384]
+566: model.layers.41.self_attn.o_proj.weight shape: [16384, 16384]
+567: model.layers.41.self_attn.q_proj.weight shape: [16384, 16384]
+568: model.layers.41.self_attn.v_proj.weight shape: [1024, 16384]
+569: model.layers.42.input_layernorm.weight shape: [16384]
+570: model.layers.42.mlp.down_proj.weight shape: [16384, 53248]
+571: model.layers.42.mlp.gate_proj.weight shape: [53248, 16384]
+572: model.layers.42.mlp.up_proj.weight shape: [53248, 16384]
+573: model.layers.42.post_attention_layernorm.weight shape: [16384]
+574: model.layers.42.self_attn.k_proj.weight shape: [1024, 16384]
+575: model.layers.42.self_attn.o_proj.weight shape: [16384, 16384]
+576: model.layers.42.self_attn.q_proj.weight shape: [16384, 16384]
+577: model.layers.42.self_attn.v_proj.weight shape: [1024, 16384]
+578: model.layers.43.input_layernorm.weight shape: [16384]
+579: model.layers.43.mlp.down_proj.weight shape: [16384, 53248]
+580: model.layers.43.mlp.gate_proj.weight shape: [53248, 16384]
+581: model.layers.43.mlp.up_proj.weight shape: [53248, 16384]
+582: model.layers.43.post_attention_layernorm.weight shape: [16384]
+583: model.layers.43.self_attn.k_proj.weight shape: [1024, 16384]
+584: model.layers.43.self_attn.o_proj.weight shape: [16384, 16384]
+585: model.layers.43.self_attn.q_proj.weight shape: [16384, 16384]
+586: model.layers.43.self_attn.v_proj.weight shape: [1024, 16384]
+587: model.layers.44.input_layernorm.weight shape: [16384]
+588: model.layers.44.mlp.down_proj.weight shape: [16384, 53248]
+589: model.layers.44.mlp.gate_proj.weight shape: [53248, 16384]
+590: model.layers.44.mlp.up_proj.weight shape: [53248, 16384]
+591: model.layers.44.post_attention_layernorm.weight shape: [16384]
+592: model.layers.44.self_attn.k_proj.weight shape: [1024, 16384]
+593: model.layers.44.self_attn.o_proj.weight shape: [16384, 16384]
+594: model.layers.44.self_attn.q_proj.weight shape: [16384, 16384]
+595: model.layers.44.self_attn.v_proj.weight shape: [1024, 16384]
+596: model.layers.45.input_layernorm.weight shape: [16384]
+597: model.layers.45.mlp.down_proj.weight shape: [16384, 53248]
+598: model.layers.45.mlp.gate_proj.weight shape: [53248, 16384]
+599: model.layers.45.mlp.up_proj.weight shape: [53248, 16384]
+600: model.layers.45.post_attention_layernorm.weight shape: [16384]
+601: model.layers.45.self_attn.k_proj.weight shape: [1024, 16384]
+602: model.layers.45.self_attn.o_proj.weight shape: [16384, 16384]
+603: model.layers.45.self_attn.q_proj.weight shape: [16384, 16384]
+604: model.layers.45.self_attn.v_proj.weight shape: [1024, 16384]
+605: model.layers.46.input_layernorm.weight shape: [16384]
+606: model.layers.46.mlp.down_proj.weight shape: [16384, 53248]
+607: model.layers.46.mlp.gate_proj.weight shape: [53248, 16384]
+608: model.layers.46.mlp.up_proj.weight shape: [53248, 16384]
+609: model.layers.46.post_attention_layernorm.weight shape: [16384]
+610: model.layers.46.self_attn.k_proj.weight shape: [1024, 16384]
+611: model.layers.46.self_attn.o_proj.weight shape: [16384, 16384]
+612: model.layers.46.self_attn.q_proj.weight shape: [16384, 16384]
+613: model.layers.46.self_attn.v_proj.weight shape: [1024, 16384]
+614: model.layers.47.input_layernorm.weight shape: [16384]
+615: model.layers.47.mlp.down_proj.weight shape: [16384, 53248]
+616: model.layers.47.mlp.gate_proj.weight shape: [53248, 16384]
+617: model.layers.47.mlp.up_proj.weight shape: [53248, 16384]
+618: model.layers.47.post_attention_layernorm.weight shape: [16384]
+619: model.layers.47.self_attn.k_proj.weight shape: [1024, 16384]
+620: model.layers.47.self_attn.o_proj.weight shape: [16384, 16384]
+621: model.layers.47.self_attn.q_proj.weight shape: [16384, 16384]
+622: model.layers.47.self_attn.v_proj.weight shape: [1024, 16384]
+623: model.layers.48.input_layernorm.weight shape: [16384]
+624: model.layers.48.mlp.down_proj.weight shape: [16384, 53248]
+625: model.layers.48.mlp.gate_proj.weight shape: [53248, 16384]
+626: model.layers.48.mlp.up_proj.weight shape: [53248, 16384]
+627: model.layers.48.post_attention_layernorm.weight shape: [16384]
+628: model.layers.48.self_attn.k_proj.weight shape: [1024, 16384]
+629: model.layers.48.self_attn.o_proj.weight shape: [16384, 16384]
+630: model.layers.48.self_attn.q_proj.weight shape: [16384, 16384]
+631: model.layers.48.self_attn.v_proj.weight shape: [1024, 16384]
+632: model.layers.49.input_layernorm.weight shape: [16384]
+633: model.layers.49.mlp.down_proj.weight shape: [16384, 53248]
+634: model.layers.49.mlp.gate_proj.weight shape: [53248, 16384]
+635: model.layers.49.mlp.up_proj.weight shape: [53248, 16384]
+636: model.layers.49.post_attention_layernorm.weight shape: [16384]
+637: model.layers.49.self_attn.k_proj.weight shape: [1024, 16384]
+638: model.layers.49.self_attn.o_proj.weight shape: [16384, 16384]
+639: model.layers.49.self_attn.q_proj.weight shape: [16384, 16384]
+640: model.layers.49.self_attn.v_proj.weight shape: [1024, 16384]
+641: model.layers.5.input_layernorm.weight shape: [16384]
+642: model.layers.5.mlp.down_proj.weight shape: [16384, 53248]
+643: model.layers.5.mlp.gate_proj.weight shape: [53248, 16384]
+644: model.layers.5.mlp.up_proj.weight shape: [53248, 16384]
+645: model.layers.5.post_attention_layernorm.weight shape: [16384]
+646: model.layers.5.self_attn.k_proj.weight shape: [1024, 16384]
+647: model.layers.5.self_attn.o_proj.weight shape: [16384, 16384]
+648: model.layers.5.self_attn.q_proj.weight shape: [16384, 16384]
+649: model.layers.5.self_attn.v_proj.weight shape: [1024, 16384]
+650: model.layers.50.input_layernorm.weight shape: [16384]
+651: model.layers.50.mlp.down_proj.weight shape: [16384, 53248]
+652: model.layers.50.mlp.gate_proj.weight shape: [53248, 16384]
+653: model.layers.50.mlp.up_proj.weight shape: [53248, 16384]
+654: model.layers.50.post_attention_layernorm.weight shape: [16384]
+655: model.layers.50.self_attn.k_proj.weight shape: [1024, 16384]
+656: model.layers.50.self_attn.o_proj.weight shape: [16384, 16384]
+657: model.layers.50.self_attn.q_proj.weight shape: [16384, 16384]
+658: model.layers.50.self_attn.v_proj.weight shape: [1024, 16384]
+659: model.layers.51.input_layernorm.weight shape: [16384]
+660: model.layers.51.mlp.down_proj.weight shape: [16384, 53248]
+661: model.layers.51.mlp.gate_proj.weight shape: [53248, 16384]
+662: model.layers.51.mlp.up_proj.weight shape: [53248, 16384]
+663: model.layers.51.post_attention_layernorm.weight shape: [16384]
+664: model.layers.51.self_attn.k_proj.weight shape: [1024, 16384]
+665: model.layers.51.self_attn.o_proj.weight shape: [16384, 16384]
+666: model.layers.51.self_attn.q_proj.weight shape: [16384, 16384]
+667: model.layers.51.self_attn.v_proj.weight shape: [1024, 16384]
+668: model.layers.52.input_layernorm.weight shape: [16384]
+669: model.layers.52.mlp.down_proj.weight shape: [16384, 53248]
+670: model.layers.52.mlp.gate_proj.weight shape: [53248, 16384]
+671: model.layers.52.mlp.up_proj.weight shape: [53248, 16384]
+672: model.layers.52.post_attention_layernorm.weight shape: [16384]
+673: model.layers.52.self_attn.k_proj.weight shape: [1024, 16384]
+674: model.layers.52.self_attn.o_proj.weight shape: [16384, 16384]
+675: model.layers.52.self_attn.q_proj.weight shape: [16384, 16384]
+676: model.layers.52.self_attn.v_proj.weight shape: [1024, 16384]
+677: model.layers.53.input_layernorm.weight shape: [16384]
+678: model.layers.53.mlp.down_proj.weight shape: [16384, 53248]
+679: model.layers.53.mlp.gate_proj.weight shape: [53248, 16384]
+680: model.layers.53.mlp.up_proj.weight shape: [53248, 16384]
+681: model.layers.53.post_attention_layernorm.weight shape: [16384]
+682: model.layers.53.self_attn.k_proj.weight shape: [1024, 16384]
+683: model.layers.53.self_attn.o_proj.weight shape: [16384, 16384]
+684: model.layers.53.self_attn.q_proj.weight shape: [16384, 16384]
+685: model.layers.53.self_attn.v_proj.weight shape: [1024, 16384]
+686: model.layers.54.input_layernorm.weight shape: [16384]
+687: model.layers.54.mlp.down_proj.weight shape: [16384, 53248]
+688: model.layers.54.mlp.gate_proj.weight shape: [53248, 16384]
+689: model.layers.54.mlp.up_proj.weight shape: [53248, 16384]
+690: model.layers.54.post_attention_layernorm.weight shape: [16384]
+691: model.layers.54.self_attn.k_proj.weight shape: [1024, 16384]
+692: model.layers.54.self_attn.o_proj.weight shape: [16384, 16384]
+693: model.layers.54.self_attn.q_proj.weight shape: [16384, 16384]
+694: model.layers.54.self_attn.v_proj.weight shape: [1024, 16384]
+695: model.layers.55.input_layernorm.weight shape: [16384]
+696: model.layers.55.mlp.down_proj.weight shape: [16384, 53248]
+697: model.layers.55.mlp.gate_proj.weight shape: [53248, 16384]
+698: model.layers.55.mlp.up_proj.weight shape: [53248, 16384]
+699: model.layers.55.post_attention_layernorm.weight shape: [16384]
+700: model.layers.55.self_attn.k_proj.weight shape: [1024, 16384]
+701: model.layers.55.self_attn.o_proj.weight shape: [16384, 16384]
+702: model.layers.55.self_attn.q_proj.weight shape: [16384, 16384]
+703: model.layers.55.self_attn.v_proj.weight shape: [1024, 16384]
+704: model.layers.56.input_layernorm.weight shape: [16384]
+705: model.layers.56.mlp.down_proj.weight shape: [16384, 53248]
+706: model.layers.56.mlp.gate_proj.weight shape: [53248, 16384]
+707: model.layers.56.mlp.up_proj.weight shape: [53248, 16384]
+708: model.layers.56.post_attention_layernorm.weight shape: [16384]
+709: model.layers.56.self_attn.k_proj.weight shape: [1024, 16384]
+710: model.layers.56.self_attn.o_proj.weight shape: [16384, 16384]
+711: model.layers.56.self_attn.q_proj.weight shape: [16384, 16384]
+712: model.layers.56.self_attn.v_proj.weight shape: [1024, 16384]
+713: model.layers.57.input_layernorm.weight shape: [16384]
+714: model.layers.57.mlp.down_proj.weight shape: [16384, 53248]
+715: model.layers.57.mlp.gate_proj.weight shape: [53248, 16384]
+716: model.layers.57.mlp.up_proj.weight shape: [53248, 16384]
+717: model.layers.57.post_attention_layernorm.weight shape: [16384]
+718: model.layers.57.self_attn.k_proj.weight shape: [1024, 16384]
+719: model.layers.57.self_attn.o_proj.weight shape: [16384, 16384]
+720: model.layers.57.self_attn.q_proj.weight shape: [16384, 16384]
+721: model.layers.57.self_attn.v_proj.weight shape: [1024, 16384]
+722: model.layers.58.input_layernorm.weight shape: [16384]
+723: model.layers.58.mlp.down_proj.weight shape: [16384, 53248]
+724: model.layers.58.mlp.gate_proj.weight shape: [53248, 16384]
+725: model.layers.58.mlp.up_proj.weight shape: [53248, 16384]
+726: model.layers.58.post_attention_layernorm.weight shape: [16384]
+727: model.layers.58.self_attn.k_proj.weight shape: [1024, 16384]
+728: model.layers.58.self_attn.o_proj.weight shape: [16384, 16384]
+729: model.layers.58.self_attn.q_proj.weight shape: [16384, 16384]
+730: model.layers.58.self_attn.v_proj.weight shape: [1024, 16384]
+731: model.layers.59.input_layernorm.weight shape: [16384]
+732: model.layers.59.mlp.down_proj.weight shape: [16384, 53248]
+733: model.layers.59.mlp.gate_proj.weight shape: [53248, 16384]
+734: model.layers.59.mlp.up_proj.weight shape: [53248, 16384]
+735: model.layers.59.post_attention_layernorm.weight shape: [16384]
+736: model.layers.59.self_attn.k_proj.weight shape: [1024, 16384]
+737: model.layers.59.self_attn.o_proj.weight shape: [16384, 16384]
+738: model.layers.59.self_attn.q_proj.weight shape: [16384, 16384]
+739: model.layers.59.self_attn.v_proj.weight shape: [1024, 16384]
+740: model.layers.6.input_layernorm.weight shape: [16384]
+741: model.layers.6.mlp.down_proj.weight shape: [16384, 53248]
+742: model.layers.6.mlp.gate_proj.weight shape: [53248, 16384]
+743: model.layers.6.mlp.up_proj.weight shape: [53248, 16384]
+744: model.layers.6.post_attention_layernorm.weight shape: [16384]
+745: model.layers.6.self_attn.k_proj.weight shape: [1024, 16384]
+746: model.layers.6.self_attn.o_proj.weight shape: [16384, 16384]
+747: model.layers.6.self_attn.q_proj.weight shape: [16384, 16384]
+748: model.layers.6.self_attn.v_proj.weight shape: [1024, 16384]
+749: model.layers.60.input_layernorm.weight shape: [16384]
+750: model.layers.60.mlp.down_proj.weight shape: [16384, 53248]
+751: model.layers.60.mlp.gate_proj.weight shape: [53248, 16384]
+752: model.layers.60.mlp.up_proj.weight shape: [53248, 16384]
+753: model.layers.60.post_attention_layernorm.weight shape: [16384]
+754: model.layers.60.self_attn.k_proj.weight shape: [1024, 16384]
+755: model.layers.60.self_attn.o_proj.weight shape: [16384, 16384]
+756: model.layers.60.self_attn.q_proj.weight shape: [16384, 16384]
+757: model.layers.60.self_attn.v_proj.weight shape: [1024, 16384]
+758: model.layers.61.input_layernorm.weight shape: [16384]
+759: model.layers.61.mlp.down_proj.weight shape: [16384, 53248]
+760: model.layers.61.mlp.gate_proj.weight shape: [53248, 16384]
+761: model.layers.61.mlp.up_proj.weight shape: [53248, 16384]
+762: model.layers.61.post_attention_layernorm.weight shape: [16384]
+763: model.layers.61.self_attn.k_proj.weight shape: [1024, 16384]
+764: model.layers.61.self_attn.o_proj.weight shape: [16384, 16384]
+765: model.layers.61.self_attn.q_proj.weight shape: [16384, 16384]
+766: model.layers.61.self_attn.v_proj.weight shape: [1024, 16384]
+767: model.layers.62.input_layernorm.weight shape: [16384]
+768: model.layers.62.mlp.down_proj.weight shape: [16384, 53248]
+769: model.layers.62.mlp.gate_proj.weight shape: [53248, 16384]
+770: model.layers.62.mlp.up_proj.weight shape: [53248, 16384]
+771: model.layers.62.post_attention_layernorm.weight shape: [16384]
+772: model.layers.62.self_attn.k_proj.weight shape: [1024, 16384]
+773: model.layers.62.self_attn.o_proj.weight shape: [16384, 16384]
+774: model.layers.62.self_attn.q_proj.weight shape: [16384, 16384]
+775: model.layers.62.self_attn.v_proj.weight shape: [1024, 16384]
+776: model.layers.63.input_layernorm.weight shape: [16384]
+777: model.layers.63.mlp.down_proj.weight shape: [16384, 53248]
+778: model.layers.63.mlp.gate_proj.weight shape: [53248, 16384]
+779: model.layers.63.mlp.up_proj.weight shape: [53248, 16384]
+780: model.layers.63.post_attention_layernorm.weight shape: [16384]
+781: model.layers.63.self_attn.k_proj.weight shape: [1024, 16384]
+782: model.layers.63.self_attn.o_proj.weight shape: [16384, 16384]
+783: model.layers.63.self_attn.q_proj.weight shape: [16384, 16384]
+784: model.layers.63.self_attn.v_proj.weight shape: [1024, 16384]
+785: model.layers.64.input_layernorm.weight shape: [16384]
+786: model.layers.64.mlp.down_proj.weight shape: [16384, 53248]
+787: model.layers.64.mlp.gate_proj.weight shape: [53248, 16384]
+788: model.layers.64.mlp.up_proj.weight shape: [53248, 16384]
+789: model.layers.64.post_attention_layernorm.weight shape: [16384]
+790: model.layers.64.self_attn.k_proj.weight shape: [1024, 16384]
+791: model.layers.64.self_attn.o_proj.weight shape: [16384, 16384]
+792: model.layers.64.self_attn.q_proj.weight shape: [16384, 16384]
+793: model.layers.64.self_attn.v_proj.weight shape: [1024, 16384]
+794: model.layers.65.input_layernorm.weight shape: [16384]
+795: model.layers.65.mlp.down_proj.weight shape: [16384, 53248]
+796: model.layers.65.mlp.gate_proj.weight shape: [53248, 16384]
+797: model.layers.65.mlp.up_proj.weight shape: [53248, 16384]
+798: model.layers.65.post_attention_layernorm.weight shape: [16384]
+799: model.layers.65.self_attn.k_proj.weight shape: [1024, 16384]
+800: model.layers.65.self_attn.o_proj.weight shape: [16384, 16384]
+801: model.layers.65.self_attn.q_proj.weight shape: [16384, 16384]
+802: model.layers.65.self_attn.v_proj.weight shape: [1024, 16384]
+803: model.layers.66.input_layernorm.weight shape: [16384]
+804: model.layers.66.mlp.down_proj.weight shape: [16384, 53248]
+805: model.layers.66.mlp.gate_proj.weight shape: [53248, 16384]
+806: model.layers.66.mlp.up_proj.weight shape: [53248, 16384]
+807: model.layers.66.post_attention_layernorm.weight shape: [16384]
+808: model.layers.66.self_attn.k_proj.weight shape: [1024, 16384]
+809: model.layers.66.self_attn.o_proj.weight shape: [16384, 16384]
+810: model.layers.66.self_attn.q_proj.weight shape: [16384, 16384]
+811: model.layers.66.self_attn.v_proj.weight shape: [1024, 16384]
+812: model.layers.67.input_layernorm.weight shape: [16384]
+813: model.layers.67.mlp.down_proj.weight shape: [16384, 53248]
+814: model.layers.67.mlp.gate_proj.weight shape: [53248, 16384]
+815: model.layers.67.mlp.up_proj.weight shape: [53248, 16384]
+816: model.layers.67.post_attention_layernorm.weight shape: [16384]
+817: model.layers.67.self_attn.k_proj.weight shape: [1024, 16384]
+818: model.layers.67.self_attn.o_proj.weight shape: [16384, 16384]
+819: model.layers.67.self_attn.q_proj.weight shape: [16384, 16384]
+820: model.layers.67.self_attn.v_proj.weight shape: [1024, 16384]
+821: model.layers.68.input_layernorm.weight shape: [16384]
+822: model.layers.68.mlp.down_proj.weight shape: [16384, 53248]
+823: model.layers.68.mlp.gate_proj.weight shape: [53248, 16384]
+824: model.layers.68.mlp.up_proj.weight shape: [53248, 16384]
+825: model.layers.68.post_attention_layernorm.weight shape: [16384]
+826: model.layers.68.self_attn.k_proj.weight shape: [1024, 16384]
+827: model.layers.68.self_attn.o_proj.weight shape: [16384, 16384]
+828: model.layers.68.self_attn.q_proj.weight shape: [16384, 16384]
+829: model.layers.68.self_attn.v_proj.weight shape: [1024, 16384]
+830: model.layers.69.input_layernorm.weight shape: [16384]
+831: model.layers.69.mlp.down_proj.weight shape: [16384, 53248]
+832: model.layers.69.mlp.gate_proj.weight shape: [53248, 16384]
+833: model.layers.69.mlp.up_proj.weight shape: [53248, 16384]
+834: model.layers.69.post_attention_layernorm.weight shape: [16384]
+835: model.layers.69.self_attn.k_proj.weight shape: [1024, 16384]
+836: model.layers.69.self_attn.o_proj.weight shape: [16384, 16384]
+837: model.layers.69.self_attn.q_proj.weight shape: [16384, 16384]
+838: model.layers.69.self_attn.v_proj.weight shape: [1024, 16384]
+839: model.layers.7.input_layernorm.weight shape: [16384]
+840: model.layers.7.mlp.down_proj.weight shape: [16384, 53248]
+841: model.layers.7.mlp.gate_proj.weight shape: [53248, 16384]
+842: model.layers.7.mlp.up_proj.weight shape: [53248, 16384]
+843: model.layers.7.post_attention_layernorm.weight shape: [16384]
+844: model.layers.7.self_attn.k_proj.weight shape: [1024, 16384]
+845: model.layers.7.self_attn.o_proj.weight shape: [16384, 16384]
+846: model.layers.7.self_attn.q_proj.weight shape: [16384, 16384]
+847: model.layers.7.self_attn.v_proj.weight shape: [1024, 16384]
+848: model.layers.70.input_layernorm.weight shape: [16384]
+849: model.layers.70.mlp.down_proj.weight shape: [16384, 53248]
+850: model.layers.70.mlp.gate_proj.weight shape: [53248, 16384]
+851: model.layers.70.mlp.up_proj.weight shape: [53248, 16384]
+852: model.layers.70.post_attention_layernorm.weight shape: [16384]
+853: model.layers.70.self_attn.k_proj.weight shape: [1024, 16384]
+854: model.layers.70.self_attn.o_proj.weight shape: [16384, 16384]
+855: model.layers.70.self_attn.q_proj.weight shape: [16384, 16384]
+856: model.layers.70.self_attn.v_proj.weight shape: [1024, 16384]
+857: model.layers.71.input_layernorm.weight shape: [16384]
+858: model.layers.71.mlp.down_proj.weight shape: [16384, 53248]
+859: model.layers.71.mlp.gate_proj.weight shape: [53248, 16384]
+860: model.layers.71.mlp.up_proj.weight shape: [53248, 16384]
+861: model.layers.71.post_attention_layernorm.weight shape: [16384]
+862: model.layers.71.self_attn.k_proj.weight shape: [1024, 16384]
+863: model.layers.71.self_attn.o_proj.weight shape: [16384, 16384]
+864: model.layers.71.self_attn.q_proj.weight shape: [16384, 16384]
+865: model.layers.71.self_attn.v_proj.weight shape: [1024, 16384]
+866: model.layers.72.input_layernorm.weight shape: [16384]
+867: model.layers.72.mlp.down_proj.weight shape: [16384, 53248]
+868: model.layers.72.mlp.gate_proj.weight shape: [53248, 16384]
+869: model.layers.72.mlp.up_proj.weight shape: [53248, 16384]
+870: model.layers.72.post_attention_layernorm.weight shape: [16384]
+871: model.layers.72.self_attn.k_proj.weight shape: [1024, 16384]
+872: model.layers.72.self_attn.o_proj.weight shape: [16384, 16384]
+873: model.layers.72.self_attn.q_proj.weight shape: [16384, 16384]
+874: model.layers.72.self_attn.v_proj.weight shape: [1024, 16384]
+875: model.layers.73.input_layernorm.weight shape: [16384]
+876: model.layers.73.mlp.down_proj.weight shape: [16384, 53248]
+877: model.layers.73.mlp.gate_proj.weight shape: [53248, 16384]
+878: model.layers.73.mlp.up_proj.weight shape: [53248, 16384]
+879: model.layers.73.post_attention_layernorm.weight shape: [16384]
+880: model.layers.73.self_attn.k_proj.weight shape: [1024, 16384]
+881: model.layers.73.self_attn.o_proj.weight shape: [16384, 16384]
+882: model.layers.73.self_attn.q_proj.weight shape: [16384, 16384]
+883: model.layers.73.self_attn.v_proj.weight shape: [1024, 16384]
+884: model.layers.74.input_layernorm.weight shape: [16384]
+885: model.layers.74.mlp.down_proj.weight shape: [16384, 53248]
+886: model.layers.74.mlp.gate_proj.weight shape: [53248, 16384]
+887: model.layers.74.mlp.up_proj.weight shape: [53248, 16384]
+888: model.layers.74.post_attention_layernorm.weight shape: [16384]
+889: model.layers.74.self_attn.k_proj.weight shape: [1024, 16384]
+890: model.layers.74.self_attn.o_proj.weight shape: [16384, 16384]
+891: model.layers.74.self_attn.q_proj.weight shape: [16384, 16384]
+892: model.layers.74.self_attn.v_proj.weight shape: [1024, 16384]
+893: model.layers.75.input_layernorm.weight shape: [16384]
+894: model.layers.75.mlp.down_proj.weight shape: [16384, 53248]
+895: model.layers.75.mlp.gate_proj.weight shape: [53248, 16384]
+896: model.layers.75.mlp.up_proj.weight shape: [53248, 16384]
+897: model.layers.75.post_attention_layernorm.weight shape: [16384]
+898: model.layers.75.self_attn.k_proj.weight shape: [1024, 16384]
+899: model.layers.75.self_attn.o_proj.weight shape: [16384, 16384]
+900: model.layers.75.self_attn.q_proj.weight shape: [16384, 16384]
+901: model.layers.75.self_attn.v_proj.weight shape: [1024, 16384]
+902: model.layers.76.input_layernorm.weight shape: [16384]
+903: model.layers.76.mlp.down_proj.weight shape: [16384, 53248]
+904: model.layers.76.mlp.gate_proj.weight shape: [53248, 16384]
+905: model.layers.76.mlp.up_proj.weight shape: [53248, 16384]
+906: model.layers.76.post_attention_layernorm.weight shape: [16384]
+907: model.layers.76.self_attn.k_proj.weight shape: [1024, 16384]
+908: model.layers.76.self_attn.o_proj.weight shape: [16384, 16384]
+909: model.layers.76.self_attn.q_proj.weight shape: [16384, 16384]
+910: model.layers.76.self_attn.v_proj.weight shape: [1024, 16384]
+911: model.layers.77.input_layernorm.weight shape: [16384]
+912: model.layers.77.mlp.down_proj.weight shape: [16384, 53248]
+913: model.layers.77.mlp.gate_proj.weight shape: [53248, 16384]
+914: model.layers.77.mlp.up_proj.weight shape: [53248, 16384]
+915: model.layers.77.post_attention_layernorm.weight shape: [16384]
+916: model.layers.77.self_attn.k_proj.weight shape: [1024, 16384]
+917: model.layers.77.self_attn.o_proj.weight shape: [16384, 16384]
+918: model.layers.77.self_attn.q_proj.weight shape: [16384, 16384]
+919: model.layers.77.self_attn.v_proj.weight shape: [1024, 16384]
+920: model.layers.78.input_layernorm.weight shape: [16384]
+921: model.layers.78.mlp.down_proj.weight shape: [16384, 53248]
+922: model.layers.78.mlp.gate_proj.weight shape: [53248, 16384]
+923: model.layers.78.mlp.up_proj.weight shape: [53248, 16384]
+924: model.layers.78.post_attention_layernorm.weight shape: [16384]
+925: model.layers.78.self_attn.k_proj.weight shape: [1024, 16384]
+926: model.layers.78.self_attn.o_proj.weight shape: [16384, 16384]
+927: model.layers.78.self_attn.q_proj.weight shape: [16384, 16384]
+928: model.layers.78.self_attn.v_proj.weight shape: [1024, 16384]
+929: model.layers.79.input_layernorm.weight shape: [16384]
+930: model.layers.79.mlp.down_proj.weight shape: [16384, 53248]
+931: model.layers.79.mlp.gate_proj.weight shape: [53248, 16384]
+932: model.layers.79.mlp.up_proj.weight shape: [53248, 16384]
+933: model.layers.79.post_attention_layernorm.weight shape: [16384]
+934: model.layers.79.self_attn.k_proj.weight shape: [1024, 16384]
+935: model.layers.79.self_attn.o_proj.weight shape: [16384, 16384]
+936: model.layers.79.self_attn.q_proj.weight shape: [16384, 16384]
+937: model.layers.79.self_attn.v_proj.weight shape: [1024, 16384]
+938: model.layers.8.input_layernorm.weight shape: [16384]
+939: model.layers.8.mlp.down_proj.weight shape: [16384, 53248]
+940: model.layers.8.mlp.gate_proj.weight shape: [53248, 16384]
+941: model.layers.8.mlp.up_proj.weight shape: [53248, 16384]
+942: model.layers.8.post_attention_layernorm.weight shape: [16384]
+943: model.layers.8.self_attn.k_proj.weight shape: [1024, 16384]
+944: model.layers.8.self_attn.o_proj.weight shape: [16384, 16384]
+945: model.layers.8.self_attn.q_proj.weight shape: [16384, 16384]
+946: model.layers.8.self_attn.v_proj.weight shape: [1024, 16384]
+947: model.layers.80.input_layernorm.weight shape: [16384]
+948: model.layers.80.mlp.down_proj.weight shape: [16384, 53248]
+949: model.layers.80.mlp.gate_proj.weight shape: [53248, 16384]
+950: model.layers.80.mlp.up_proj.weight shape: [53248, 16384]
+951: model.layers.80.post_attention_layernorm.weight shape: [16384]
+952: model.layers.80.self_attn.k_proj.weight shape: [1024, 16384]
+953: model.layers.80.self_attn.o_proj.weight shape: [16384, 16384]
+954: model.layers.80.self_attn.q_proj.weight shape: [16384, 16384]
+955: model.layers.80.self_attn.v_proj.weight shape: [1024, 16384]
+956: model.layers.81.input_layernorm.weight shape: [16384]
+957: model.layers.81.mlp.down_proj.weight shape: [16384, 53248]
+958: model.layers.81.mlp.gate_proj.weight shape: [53248, 16384]
+959: model.layers.81.mlp.up_proj.weight shape: [53248, 16384]
+960: model.layers.81.post_attention_layernorm.weight shape: [16384]
+961: model.layers.81.self_attn.k_proj.weight shape: [1024, 16384]
+962: model.layers.81.self_attn.o_proj.weight shape: [16384, 16384]
+963: model.layers.81.self_attn.q_proj.weight shape: [16384, 16384]
+964: model.layers.81.self_attn.v_proj.weight shape: [1024, 16384]
+965: model.layers.82.input_layernorm.weight shape: [16384]
+966: model.layers.82.mlp.down_proj.weight shape: [16384, 53248]
+967: model.layers.82.mlp.gate_proj.weight shape: [53248, 16384]
+968: model.layers.82.mlp.up_proj.weight shape: [53248, 16384]
+969: model.layers.82.post_attention_layernorm.weight shape: [16384]
+970: model.layers.82.self_attn.k_proj.weight shape: [1024, 16384]
+971: model.layers.82.self_attn.o_proj.weight shape: [16384, 16384]
+972: model.layers.82.self_attn.q_proj.weight shape: [16384, 16384]
+973: model.layers.82.self_attn.v_proj.weight shape: [1024, 16384]
+974: model.layers.83.input_layernorm.weight shape: [16384]
+975: model.layers.83.mlp.down_proj.weight shape: [16384, 53248]
+976: model.layers.83.mlp.gate_proj.weight shape: [53248, 16384]
+977: model.layers.83.mlp.up_proj.weight shape: [53248, 16384]
+978: model.layers.83.post_attention_layernorm.weight shape: [16384]
+979: model.layers.83.self_attn.k_proj.weight shape: [1024, 16384]
+980: model.layers.83.self_attn.o_proj.weight shape: [16384, 16384]
+981: model.layers.83.self_attn.q_proj.weight shape: [16384, 16384]
+982: model.layers.83.self_attn.v_proj.weight shape: [1024, 16384]
+983: model.layers.84.input_layernorm.weight shape: [16384]
+984: model.layers.84.mlp.down_proj.weight shape: [16384, 53248]
+985: model.layers.84.mlp.gate_proj.weight shape: [53248, 16384]
+986: model.layers.84.mlp.up_proj.weight shape: [53248, 16384]
+987: model.layers.84.post_attention_layernorm.weight shape: [16384]
+988: model.layers.84.self_attn.k_proj.weight shape: [1024, 16384]
+989: model.layers.84.self_attn.o_proj.weight shape: [16384, 16384]
+990: model.layers.84.self_attn.q_proj.weight shape: [16384, 16384]
+991: model.layers.84.self_attn.v_proj.weight shape: [1024, 16384]
+992: model.layers.85.input_layernorm.weight shape: [16384]
+993: model.layers.85.mlp.down_proj.weight shape: [16384, 53248]
+994: model.layers.85.mlp.gate_proj.weight shape: [53248, 16384]
+995: model.layers.85.mlp.up_proj.weight shape: [53248, 16384]
+996: model.layers.85.post_attention_layernorm.weight shape: [16384]
+997: model.layers.85.self_attn.k_proj.weight shape: [1024, 16384]
+998: model.layers.85.self_attn.o_proj.weight shape: [16384, 16384]
+999: model.layers.85.self_attn.q_proj.weight shape: [16384, 16384]
+1000: model.layers.85.self_attn.v_proj.weight shape: [1024, 16384]
+1001: model.layers.86.input_layernorm.weight shape: [16384]
+1002: model.layers.86.mlp.down_proj.weight shape: [16384, 53248]
+1003: model.layers.86.mlp.gate_proj.weight shape: [53248, 16384]
+1004: model.layers.86.mlp.up_proj.weight shape: [53248, 16384]
+1005: model.layers.86.post_attention_layernorm.weight shape: [16384]
+1006: model.layers.86.self_attn.k_proj.weight shape: [1024, 16384]
+1007: model.layers.86.self_attn.o_proj.weight shape: [16384, 16384]
+1008: model.layers.86.self_attn.q_proj.weight shape: [16384, 16384]
+1009: model.layers.86.self_attn.v_proj.weight shape: [1024, 16384]
+1010: model.layers.87.input_layernorm.weight shape: [16384]
+1011: model.layers.87.mlp.down_proj.weight shape: [16384, 53248]
+1012: model.layers.87.mlp.gate_proj.weight shape: [53248, 16384]
+1013: model.layers.87.mlp.up_proj.weight shape: [53248, 16384]
+1014: model.layers.87.post_attention_layernorm.weight shape: [16384]
+1015: model.layers.87.self_attn.k_proj.weight shape: [1024, 16384]
+1016: model.layers.87.self_attn.o_proj.weight shape: [16384, 16384]
+1017: model.layers.87.self_attn.q_proj.weight shape: [16384, 16384]
+1018: model.layers.87.self_attn.v_proj.weight shape: [1024, 16384]
+1019: model.layers.88.input_layernorm.weight shape: [16384]
+1020: model.layers.88.mlp.down_proj.weight shape: [16384, 53248]
+1021: model.layers.88.mlp.gate_proj.weight shape: [53248, 16384]
+1022: model.layers.88.mlp.up_proj.weight shape: [53248, 16384]
+1023: model.layers.88.post_attention_layernorm.weight shape: [16384]
+1024: model.layers.88.self_attn.k_proj.weight shape: [1024, 16384]
+1025: model.layers.88.self_attn.o_proj.weight shape: [16384, 16384]
+1026: model.layers.88.self_attn.q_proj.weight shape: [16384, 16384]
+1027: model.layers.88.self_attn.v_proj.weight shape: [1024, 16384]
+1028: model.layers.89.input_layernorm.weight shape: [16384]
+1029: model.layers.89.mlp.down_proj.weight shape: [16384, 53248]
+1030: model.layers.89.mlp.gate_proj.weight shape: [53248, 16384]
+1031: model.layers.89.mlp.up_proj.weight shape: [53248, 16384]
+1032: model.layers.89.post_attention_layernorm.weight shape: [16384]
+1033: model.layers.89.self_attn.k_proj.weight shape: [1024, 16384]
+1034: model.layers.89.self_attn.o_proj.weight shape: [16384, 16384]
+1035: model.layers.89.self_attn.q_proj.weight shape: [16384, 16384]
+1036: model.layers.89.self_attn.v_proj.weight shape: [1024, 16384]
+1037: model.layers.9.input_layernorm.weight shape: [16384]
+1038: model.layers.9.mlp.down_proj.weight shape: [16384, 53248]
+1039: model.layers.9.mlp.gate_proj.weight shape: [53248, 16384]
+1040: model.layers.9.mlp.up_proj.weight shape: [53248, 16384]
+1041: model.layers.9.post_attention_layernorm.weight shape: [16384]
+1042: model.layers.9.self_attn.k_proj.weight shape: [1024, 16384]
+1043: model.layers.9.self_attn.o_proj.weight shape: [16384, 16384]
+1044: model.layers.9.self_attn.q_proj.weight shape: [16384, 16384]
+1045: model.layers.9.self_attn.v_proj.weight shape: [1024, 16384]
+1046: model.layers.90.input_layernorm.weight shape: [16384]
+1047: model.layers.90.mlp.down_proj.weight shape: [16384, 53248]
+1048: model.layers.90.mlp.gate_proj.weight shape: [53248, 16384]
+1049: model.layers.90.mlp.up_proj.weight shape: [53248, 16384]
+1050: model.layers.90.post_attention_layernorm.weight shape: [16384]
+1051: model.layers.90.self_attn.k_proj.weight shape: [1024, 16384]
+1052: model.layers.90.self_attn.o_proj.weight shape: [16384, 16384]
+1053: model.layers.90.self_attn.q_proj.weight shape: [16384, 16384]
+1054: model.layers.90.self_attn.v_proj.weight shape: [1024, 16384]
+1055: model.layers.91.input_layernorm.weight shape: [16384]
+1056: model.layers.91.mlp.down_proj.weight shape: [16384, 53248]
+1057: model.layers.91.mlp.gate_proj.weight shape: [53248, 16384]
+1058: model.layers.91.mlp.up_proj.weight shape: [53248, 16384]
+1059: model.layers.91.post_attention_layernorm.weight shape: [16384]
+1060: model.layers.91.self_attn.k_proj.weight shape: [1024, 16384]
+1061: model.layers.91.self_attn.o_proj.weight shape: [16384, 16384]
+1062: model.layers.91.self_attn.q_proj.weight shape: [16384, 16384]
+1063: model.layers.91.self_attn.v_proj.weight shape: [1024, 16384]
+1064: model.layers.92.input_layernorm.weight shape: [16384]
+1065: model.layers.92.mlp.down_proj.weight shape: [16384, 53248]
+1066: model.layers.92.mlp.gate_proj.weight shape: [53248, 16384]
+1067: model.layers.92.mlp.up_proj.weight shape: [53248, 16384]
+1068: model.layers.92.post_attention_layernorm.weight shape: [16384]
+1069: model.layers.92.self_attn.k_proj.weight shape: [1024, 16384]
+1070: model.layers.92.self_attn.o_proj.weight shape: [16384, 16384]
+1071: model.layers.92.self_attn.q_proj.weight shape: [16384, 16384]
+1072: model.layers.92.self_attn.v_proj.weight shape: [1024, 16384]
+1073: model.layers.93.input_layernorm.weight shape: [16384]
+1074: model.layers.93.mlp.down_proj.weight shape: [16384, 53248]
+1075: model.layers.93.mlp.gate_proj.weight shape: [53248, 16384]
+1076: model.layers.93.mlp.up_proj.weight shape: [53248, 16384]
+1077: model.layers.93.post_attention_layernorm.weight shape: [16384]
+1078: model.layers.93.self_attn.k_proj.weight shape: [1024, 16384]
+1079: model.layers.93.self_attn.o_proj.weight shape: [16384, 16384]
+1080: model.layers.93.self_attn.q_proj.weight shape: [16384, 16384]
+1081: model.layers.93.self_attn.v_proj.weight shape: [1024, 16384]
+1082: model.layers.94.input_layernorm.weight shape: [16384]
+1083: model.layers.94.mlp.down_proj.weight shape: [16384, 53248]
+1084: model.layers.94.mlp.gate_proj.weight shape: [53248, 16384]
+1085: model.layers.94.mlp.up_proj.weight shape: [53248, 16384]
+1086: model.layers.94.post_attention_layernorm.weight shape: [16384]
+1087: model.layers.94.self_attn.k_proj.weight shape: [1024, 16384]
+1088: model.layers.94.self_attn.o_proj.weight shape: [16384, 16384]
+1089: model.layers.94.self_attn.q_proj.weight shape: [16384, 16384]
+1090: model.layers.94.self_attn.v_proj.weight shape: [1024, 16384]
+1091: model.layers.95.input_layernorm.weight shape: [16384]
+1092: model.layers.95.mlp.down_proj.weight shape: [16384, 53248]
+1093: model.layers.95.mlp.gate_proj.weight shape: [53248, 16384]
+1094: model.layers.95.mlp.up_proj.weight shape: [53248, 16384]
+1095: model.layers.95.post_attention_layernorm.weight shape: [16384]
+1096: model.layers.95.self_attn.k_proj.weight shape: [1024, 16384]
+1097: model.layers.95.self_attn.o_proj.weight shape: [16384, 16384]
+1098: model.layers.95.self_attn.q_proj.weight shape: [16384, 16384]
+1099: model.layers.95.self_attn.v_proj.weight shape: [1024, 16384]
+1100: model.layers.96.input_layernorm.weight shape: [16384]
+1101: model.layers.96.mlp.down_proj.weight shape: [16384, 53248]
+1102: model.layers.96.mlp.gate_proj.weight shape: [53248, 16384]
+1103: model.layers.96.mlp.up_proj.weight shape: [53248, 16384]
+1104: model.layers.96.post_attention_layernorm.weight shape: [16384]
+1105: model.layers.96.self_attn.k_proj.weight shape: [1024, 16384]
+1106: model.layers.96.self_attn.o_proj.weight shape: [16384, 16384]
+1107: model.layers.96.self_attn.q_proj.weight shape: [16384, 16384]
+1108: model.layers.96.self_attn.v_proj.weight shape: [1024, 16384]
+1109: model.layers.97.input_layernorm.weight shape: [16384]
+1110: model.layers.97.mlp.down_proj.weight shape: [16384, 53248]
+1111: model.layers.97.mlp.gate_proj.weight shape: [53248, 16384]
+1112: model.layers.97.mlp.up_proj.weight shape: [53248, 16384]
+1113: model.layers.97.post_attention_layernorm.weight shape: [16384]
+1114: model.layers.97.self_attn.k_proj.weight shape: [1024, 16384]
+1115: model.layers.97.self_attn.o_proj.weight shape: [16384, 16384]
+1116: model.layers.97.self_attn.q_proj.weight shape: [16384, 16384]
+1117: model.layers.97.self_attn.v_proj.weight shape: [1024, 16384]
+1118: model.layers.98.input_layernorm.weight shape: [16384]
+1119: model.layers.98.mlp.down_proj.weight shape: [16384, 53248]
+1120: model.layers.98.mlp.gate_proj.weight shape: [53248, 16384]
+1121: model.layers.98.mlp.up_proj.weight shape: [53248, 16384]
+1122: model.layers.98.post_attention_layernorm.weight shape: [16384]
+1123: model.layers.98.self_attn.k_proj.weight shape: [1024, 16384]
+1124: model.layers.98.self_attn.o_proj.weight shape: [16384, 16384]
+1125: model.layers.98.self_attn.q_proj.weight shape: [16384, 16384]
+1126: model.layers.98.self_attn.v_proj.weight shape: [1024, 16384]
+1127: model.layers.99.input_layernorm.weight shape: [16384]
+1128: model.layers.99.mlp.down_proj.weight shape: [16384, 53248]
+1129: model.layers.99.mlp.gate_proj.weight shape: [53248, 16384]
+1130: model.layers.99.mlp.up_proj.weight shape: [53248, 16384]
+1131: model.layers.99.post_attention_layernorm.weight shape: [16384]
+1132: model.layers.99.self_attn.k_proj.weight shape: [1024, 16384]
+1133: model.layers.99.self_attn.o_proj.weight shape: [16384, 16384]
+1134: model.layers.99.self_attn.q_proj.weight shape: [16384, 16384]
+1135: model.layers.99.self_attn.v_proj.weight shape: [1024, 16384]
+1136: model.norm.weight shape: [16384]
diff --git a/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.Llama_3_1_70b_ShapeTest.approved.txt b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.Llama_3_1_70b_ShapeTest.approved.txt
new file mode 100644
index 0000000000..5add8770c5
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.Llama_3_1_70b_ShapeTest.approved.txt
@@ -0,0 +1,723 @@
+0: lm_head.weight shape: [128256, 8192]
+1: model.embed_tokens.weight shape: [128256, 8192]
+2: model.layers.0.input_layernorm.weight shape: [8192]
+3: model.layers.0.mlp.down_proj.weight shape: [8192, 28672]
+4: model.layers.0.mlp.gate_proj.weight shape: [28672, 8192]
+5: model.layers.0.mlp.up_proj.weight shape: [28672, 8192]
+6: model.layers.0.post_attention_layernorm.weight shape: [8192]
+7: model.layers.0.self_attn.k_proj.weight shape: [1024, 8192]
+8: model.layers.0.self_attn.o_proj.weight shape: [8192, 8192]
+9: model.layers.0.self_attn.q_proj.weight shape: [8192, 8192]
+10: model.layers.0.self_attn.v_proj.weight shape: [1024, 8192]
+11: model.layers.1.input_layernorm.weight shape: [8192]
+12: model.layers.1.mlp.down_proj.weight shape: [8192, 28672]
+13: model.layers.1.mlp.gate_proj.weight shape: [28672, 8192]
+14: model.layers.1.mlp.up_proj.weight shape: [28672, 8192]
+15: model.layers.1.post_attention_layernorm.weight shape: [8192]
+16: model.layers.1.self_attn.k_proj.weight shape: [1024, 8192]
+17: model.layers.1.self_attn.o_proj.weight shape: [8192, 8192]
+18: model.layers.1.self_attn.q_proj.weight shape: [8192, 8192]
+19: model.layers.1.self_attn.v_proj.weight shape: [1024, 8192]
+20: model.layers.10.input_layernorm.weight shape: [8192]
+21: model.layers.10.mlp.down_proj.weight shape: [8192, 28672]
+22: model.layers.10.mlp.gate_proj.weight shape: [28672, 8192]
+23: model.layers.10.mlp.up_proj.weight shape: [28672, 8192]
+24: model.layers.10.post_attention_layernorm.weight shape: [8192]
+25: model.layers.10.self_attn.k_proj.weight shape: [1024, 8192]
+26: model.layers.10.self_attn.o_proj.weight shape: [8192, 8192]
+27: model.layers.10.self_attn.q_proj.weight shape: [8192, 8192]
+28: model.layers.10.self_attn.v_proj.weight shape: [1024, 8192]
+29: model.layers.11.input_layernorm.weight shape: [8192]
+30: model.layers.11.mlp.down_proj.weight shape: [8192, 28672]
+31: model.layers.11.mlp.gate_proj.weight shape: [28672, 8192]
+32: model.layers.11.mlp.up_proj.weight shape: [28672, 8192]
+33: model.layers.11.post_attention_layernorm.weight shape: [8192]
+34: model.layers.11.self_attn.k_proj.weight shape: [1024, 8192]
+35: model.layers.11.self_attn.o_proj.weight shape: [8192, 8192]
+36: model.layers.11.self_attn.q_proj.weight shape: [8192, 8192]
+37: model.layers.11.self_attn.v_proj.weight shape: [1024, 8192]
+38: model.layers.12.input_layernorm.weight shape: [8192]
+39: model.layers.12.mlp.down_proj.weight shape: [8192, 28672]
+40: model.layers.12.mlp.gate_proj.weight shape: [28672, 8192]
+41: model.layers.12.mlp.up_proj.weight shape: [28672, 8192]
+42: model.layers.12.post_attention_layernorm.weight shape: [8192]
+43: model.layers.12.self_attn.k_proj.weight shape: [1024, 8192]
+44: model.layers.12.self_attn.o_proj.weight shape: [8192, 8192]
+45: model.layers.12.self_attn.q_proj.weight shape: [8192, 8192]
+46: model.layers.12.self_attn.v_proj.weight shape: [1024, 8192]
+47: model.layers.13.input_layernorm.weight shape: [8192]
+48: model.layers.13.mlp.down_proj.weight shape: [8192, 28672]
+49: model.layers.13.mlp.gate_proj.weight shape: [28672, 8192]
+50: model.layers.13.mlp.up_proj.weight shape: [28672, 8192]
+51: model.layers.13.post_attention_layernorm.weight shape: [8192]
+52: model.layers.13.self_attn.k_proj.weight shape: [1024, 8192]
+53: model.layers.13.self_attn.o_proj.weight shape: [8192, 8192]
+54: model.layers.13.self_attn.q_proj.weight shape: [8192, 8192]
+55: model.layers.13.self_attn.v_proj.weight shape: [1024, 8192]
+56: model.layers.14.input_layernorm.weight shape: [8192]
+57: model.layers.14.mlp.down_proj.weight shape: [8192, 28672]
+58: model.layers.14.mlp.gate_proj.weight shape: [28672, 8192]
+59: model.layers.14.mlp.up_proj.weight shape: [28672, 8192]
+60: model.layers.14.post_attention_layernorm.weight shape: [8192]
+61: model.layers.14.self_attn.k_proj.weight shape: [1024, 8192]
+62: model.layers.14.self_attn.o_proj.weight shape: [8192, 8192]
+63: model.layers.14.self_attn.q_proj.weight shape: [8192, 8192]
+64: model.layers.14.self_attn.v_proj.weight shape: [1024, 8192]
+65: model.layers.15.input_layernorm.weight shape: [8192]
+66: model.layers.15.mlp.down_proj.weight shape: [8192, 28672]
+67: model.layers.15.mlp.gate_proj.weight shape: [28672, 8192]
+68: model.layers.15.mlp.up_proj.weight shape: [28672, 8192]
+69: model.layers.15.post_attention_layernorm.weight shape: [8192]
+70: model.layers.15.self_attn.k_proj.weight shape: [1024, 8192]
+71: model.layers.15.self_attn.o_proj.weight shape: [8192, 8192]
+72: model.layers.15.self_attn.q_proj.weight shape: [8192, 8192]
+73: model.layers.15.self_attn.v_proj.weight shape: [1024, 8192]
+74: model.layers.16.input_layernorm.weight shape: [8192]
+75: model.layers.16.mlp.down_proj.weight shape: [8192, 28672]
+76: model.layers.16.mlp.gate_proj.weight shape: [28672, 8192]
+77: model.layers.16.mlp.up_proj.weight shape: [28672, 8192]
+78: model.layers.16.post_attention_layernorm.weight shape: [8192]
+79: model.layers.16.self_attn.k_proj.weight shape: [1024, 8192]
+80: model.layers.16.self_attn.o_proj.weight shape: [8192, 8192]
+81: model.layers.16.self_attn.q_proj.weight shape: [8192, 8192]
+82: model.layers.16.self_attn.v_proj.weight shape: [1024, 8192]
+83: model.layers.17.input_layernorm.weight shape: [8192]
+84: model.layers.17.mlp.down_proj.weight shape: [8192, 28672]
+85: model.layers.17.mlp.gate_proj.weight shape: [28672, 8192]
+86: model.layers.17.mlp.up_proj.weight shape: [28672, 8192]
+87: model.layers.17.post_attention_layernorm.weight shape: [8192]
+88: model.layers.17.self_attn.k_proj.weight shape: [1024, 8192]
+89: model.layers.17.self_attn.o_proj.weight shape: [8192, 8192]
+90: model.layers.17.self_attn.q_proj.weight shape: [8192, 8192]
+91: model.layers.17.self_attn.v_proj.weight shape: [1024, 8192]
+92: model.layers.18.input_layernorm.weight shape: [8192]
+93: model.layers.18.mlp.down_proj.weight shape: [8192, 28672]
+94: model.layers.18.mlp.gate_proj.weight shape: [28672, 8192]
+95: model.layers.18.mlp.up_proj.weight shape: [28672, 8192]
+96: model.layers.18.post_attention_layernorm.weight shape: [8192]
+97: model.layers.18.self_attn.k_proj.weight shape: [1024, 8192]
+98: model.layers.18.self_attn.o_proj.weight shape: [8192, 8192]
+99: model.layers.18.self_attn.q_proj.weight shape: [8192, 8192]
+100: model.layers.18.self_attn.v_proj.weight shape: [1024, 8192]
+101: model.layers.19.input_layernorm.weight shape: [8192]
+102: model.layers.19.mlp.down_proj.weight shape: [8192, 28672]
+103: model.layers.19.mlp.gate_proj.weight shape: [28672, 8192]
+104: model.layers.19.mlp.up_proj.weight shape: [28672, 8192]
+105: model.layers.19.post_attention_layernorm.weight shape: [8192]
+106: model.layers.19.self_attn.k_proj.weight shape: [1024, 8192]
+107: model.layers.19.self_attn.o_proj.weight shape: [8192, 8192]
+108: model.layers.19.self_attn.q_proj.weight shape: [8192, 8192]
+109: model.layers.19.self_attn.v_proj.weight shape: [1024, 8192]
+110: model.layers.2.input_layernorm.weight shape: [8192]
+111: model.layers.2.mlp.down_proj.weight shape: [8192, 28672]
+112: model.layers.2.mlp.gate_proj.weight shape: [28672, 8192]
+113: model.layers.2.mlp.up_proj.weight shape: [28672, 8192]
+114: model.layers.2.post_attention_layernorm.weight shape: [8192]
+115: model.layers.2.self_attn.k_proj.weight shape: [1024, 8192]
+116: model.layers.2.self_attn.o_proj.weight shape: [8192, 8192]
+117: model.layers.2.self_attn.q_proj.weight shape: [8192, 8192]
+118: model.layers.2.self_attn.v_proj.weight shape: [1024, 8192]
+119: model.layers.20.input_layernorm.weight shape: [8192]
+120: model.layers.20.mlp.down_proj.weight shape: [8192, 28672]
+121: model.layers.20.mlp.gate_proj.weight shape: [28672, 8192]
+122: model.layers.20.mlp.up_proj.weight shape: [28672, 8192]
+123: model.layers.20.post_attention_layernorm.weight shape: [8192]
+124: model.layers.20.self_attn.k_proj.weight shape: [1024, 8192]
+125: model.layers.20.self_attn.o_proj.weight shape: [8192, 8192]
+126: model.layers.20.self_attn.q_proj.weight shape: [8192, 8192]
+127: model.layers.20.self_attn.v_proj.weight shape: [1024, 8192]
+128: model.layers.21.input_layernorm.weight shape: [8192]
+129: model.layers.21.mlp.down_proj.weight shape: [8192, 28672]
+130: model.layers.21.mlp.gate_proj.weight shape: [28672, 8192]
+131: model.layers.21.mlp.up_proj.weight shape: [28672, 8192]
+132: model.layers.21.post_attention_layernorm.weight shape: [8192]
+133: model.layers.21.self_attn.k_proj.weight shape: [1024, 8192]
+134: model.layers.21.self_attn.o_proj.weight shape: [8192, 8192]
+135: model.layers.21.self_attn.q_proj.weight shape: [8192, 8192]
+136: model.layers.21.self_attn.v_proj.weight shape: [1024, 8192]
+137: model.layers.22.input_layernorm.weight shape: [8192]
+138: model.layers.22.mlp.down_proj.weight shape: [8192, 28672]
+139: model.layers.22.mlp.gate_proj.weight shape: [28672, 8192]
+140: model.layers.22.mlp.up_proj.weight shape: [28672, 8192]
+141: model.layers.22.post_attention_layernorm.weight shape: [8192]
+142: model.layers.22.self_attn.k_proj.weight shape: [1024, 8192]
+143: model.layers.22.self_attn.o_proj.weight shape: [8192, 8192]
+144: model.layers.22.self_attn.q_proj.weight shape: [8192, 8192]
+145: model.layers.22.self_attn.v_proj.weight shape: [1024, 8192]
+146: model.layers.23.input_layernorm.weight shape: [8192]
+147: model.layers.23.mlp.down_proj.weight shape: [8192, 28672]
+148: model.layers.23.mlp.gate_proj.weight shape: [28672, 8192]
+149: model.layers.23.mlp.up_proj.weight shape: [28672, 8192]
+150: model.layers.23.post_attention_layernorm.weight shape: [8192]
+151: model.layers.23.self_attn.k_proj.weight shape: [1024, 8192]
+152: model.layers.23.self_attn.o_proj.weight shape: [8192, 8192]
+153: model.layers.23.self_attn.q_proj.weight shape: [8192, 8192]
+154: model.layers.23.self_attn.v_proj.weight shape: [1024, 8192]
+155: model.layers.24.input_layernorm.weight shape: [8192]
+156: model.layers.24.mlp.down_proj.weight shape: [8192, 28672]
+157: model.layers.24.mlp.gate_proj.weight shape: [28672, 8192]
+158: model.layers.24.mlp.up_proj.weight shape: [28672, 8192]
+159: model.layers.24.post_attention_layernorm.weight shape: [8192]
+160: model.layers.24.self_attn.k_proj.weight shape: [1024, 8192]
+161: model.layers.24.self_attn.o_proj.weight shape: [8192, 8192]
+162: model.layers.24.self_attn.q_proj.weight shape: [8192, 8192]
+163: model.layers.24.self_attn.v_proj.weight shape: [1024, 8192]
+164: model.layers.25.input_layernorm.weight shape: [8192]
+165: model.layers.25.mlp.down_proj.weight shape: [8192, 28672]
+166: model.layers.25.mlp.gate_proj.weight shape: [28672, 8192]
+167: model.layers.25.mlp.up_proj.weight shape: [28672, 8192]
+168: model.layers.25.post_attention_layernorm.weight shape: [8192]
+169: model.layers.25.self_attn.k_proj.weight shape: [1024, 8192]
+170: model.layers.25.self_attn.o_proj.weight shape: [8192, 8192]
+171: model.layers.25.self_attn.q_proj.weight shape: [8192, 8192]
+172: model.layers.25.self_attn.v_proj.weight shape: [1024, 8192]
+173: model.layers.26.input_layernorm.weight shape: [8192]
+174: model.layers.26.mlp.down_proj.weight shape: [8192, 28672]
+175: model.layers.26.mlp.gate_proj.weight shape: [28672, 8192]
+176: model.layers.26.mlp.up_proj.weight shape: [28672, 8192]
+177: model.layers.26.post_attention_layernorm.weight shape: [8192]
+178: model.layers.26.self_attn.k_proj.weight shape: [1024, 8192]
+179: model.layers.26.self_attn.o_proj.weight shape: [8192, 8192]
+180: model.layers.26.self_attn.q_proj.weight shape: [8192, 8192]
+181: model.layers.26.self_attn.v_proj.weight shape: [1024, 8192]
+182: model.layers.27.input_layernorm.weight shape: [8192]
+183: model.layers.27.mlp.down_proj.weight shape: [8192, 28672]
+184: model.layers.27.mlp.gate_proj.weight shape: [28672, 8192]
+185: model.layers.27.mlp.up_proj.weight shape: [28672, 8192]
+186: model.layers.27.post_attention_layernorm.weight shape: [8192]
+187: model.layers.27.self_attn.k_proj.weight shape: [1024, 8192]
+188: model.layers.27.self_attn.o_proj.weight shape: [8192, 8192]
+189: model.layers.27.self_attn.q_proj.weight shape: [8192, 8192]
+190: model.layers.27.self_attn.v_proj.weight shape: [1024, 8192]
+191: model.layers.28.input_layernorm.weight shape: [8192]
+192: model.layers.28.mlp.down_proj.weight shape: [8192, 28672]
+193: model.layers.28.mlp.gate_proj.weight shape: [28672, 8192]
+194: model.layers.28.mlp.up_proj.weight shape: [28672, 8192]
+195: model.layers.28.post_attention_layernorm.weight shape: [8192]
+196: model.layers.28.self_attn.k_proj.weight shape: [1024, 8192]
+197: model.layers.28.self_attn.o_proj.weight shape: [8192, 8192]
+198: model.layers.28.self_attn.q_proj.weight shape: [8192, 8192]
+199: model.layers.28.self_attn.v_proj.weight shape: [1024, 8192]
+200: model.layers.29.input_layernorm.weight shape: [8192]
+201: model.layers.29.mlp.down_proj.weight shape: [8192, 28672]
+202: model.layers.29.mlp.gate_proj.weight shape: [28672, 8192]
+203: model.layers.29.mlp.up_proj.weight shape: [28672, 8192]
+204: model.layers.29.post_attention_layernorm.weight shape: [8192]
+205: model.layers.29.self_attn.k_proj.weight shape: [1024, 8192]
+206: model.layers.29.self_attn.o_proj.weight shape: [8192, 8192]
+207: model.layers.29.self_attn.q_proj.weight shape: [8192, 8192]
+208: model.layers.29.self_attn.v_proj.weight shape: [1024, 8192]
+209: model.layers.3.input_layernorm.weight shape: [8192]
+210: model.layers.3.mlp.down_proj.weight shape: [8192, 28672]
+211: model.layers.3.mlp.gate_proj.weight shape: [28672, 8192]
+212: model.layers.3.mlp.up_proj.weight shape: [28672, 8192]
+213: model.layers.3.post_attention_layernorm.weight shape: [8192]
+214: model.layers.3.self_attn.k_proj.weight shape: [1024, 8192]
+215: model.layers.3.self_attn.o_proj.weight shape: [8192, 8192]
+216: model.layers.3.self_attn.q_proj.weight shape: [8192, 8192]
+217: model.layers.3.self_attn.v_proj.weight shape: [1024, 8192]
+218: model.layers.30.input_layernorm.weight shape: [8192]
+219: model.layers.30.mlp.down_proj.weight shape: [8192, 28672]
+220: model.layers.30.mlp.gate_proj.weight shape: [28672, 8192]
+221: model.layers.30.mlp.up_proj.weight shape: [28672, 8192]
+222: model.layers.30.post_attention_layernorm.weight shape: [8192]
+223: model.layers.30.self_attn.k_proj.weight shape: [1024, 8192]
+224: model.layers.30.self_attn.o_proj.weight shape: [8192, 8192]
+225: model.layers.30.self_attn.q_proj.weight shape: [8192, 8192]
+226: model.layers.30.self_attn.v_proj.weight shape: [1024, 8192]
+227: model.layers.31.input_layernorm.weight shape: [8192]
+228: model.layers.31.mlp.down_proj.weight shape: [8192, 28672]
+229: model.layers.31.mlp.gate_proj.weight shape: [28672, 8192]
+230: model.layers.31.mlp.up_proj.weight shape: [28672, 8192]
+231: model.layers.31.post_attention_layernorm.weight shape: [8192]
+232: model.layers.31.self_attn.k_proj.weight shape: [1024, 8192]
+233: model.layers.31.self_attn.o_proj.weight shape: [8192, 8192]
+234: model.layers.31.self_attn.q_proj.weight shape: [8192, 8192]
+235: model.layers.31.self_attn.v_proj.weight shape: [1024, 8192]
+236: model.layers.32.input_layernorm.weight shape: [8192]
+237: model.layers.32.mlp.down_proj.weight shape: [8192, 28672]
+238: model.layers.32.mlp.gate_proj.weight shape: [28672, 8192]
+239: model.layers.32.mlp.up_proj.weight shape: [28672, 8192]
+240: model.layers.32.post_attention_layernorm.weight shape: [8192]
+241: model.layers.32.self_attn.k_proj.weight shape: [1024, 8192]
+242: model.layers.32.self_attn.o_proj.weight shape: [8192, 8192]
+243: model.layers.32.self_attn.q_proj.weight shape: [8192, 8192]
+244: model.layers.32.self_attn.v_proj.weight shape: [1024, 8192]
+245: model.layers.33.input_layernorm.weight shape: [8192]
+246: model.layers.33.mlp.down_proj.weight shape: [8192, 28672]
+247: model.layers.33.mlp.gate_proj.weight shape: [28672, 8192]
+248: model.layers.33.mlp.up_proj.weight shape: [28672, 8192]
+249: model.layers.33.post_attention_layernorm.weight shape: [8192]
+250: model.layers.33.self_attn.k_proj.weight shape: [1024, 8192]
+251: model.layers.33.self_attn.o_proj.weight shape: [8192, 8192]
+252: model.layers.33.self_attn.q_proj.weight shape: [8192, 8192]
+253: model.layers.33.self_attn.v_proj.weight shape: [1024, 8192]
+254: model.layers.34.input_layernorm.weight shape: [8192]
+255: model.layers.34.mlp.down_proj.weight shape: [8192, 28672]
+256: model.layers.34.mlp.gate_proj.weight shape: [28672, 8192]
+257: model.layers.34.mlp.up_proj.weight shape: [28672, 8192]
+258: model.layers.34.post_attention_layernorm.weight shape: [8192]
+259: model.layers.34.self_attn.k_proj.weight shape: [1024, 8192]
+260: model.layers.34.self_attn.o_proj.weight shape: [8192, 8192]
+261: model.layers.34.self_attn.q_proj.weight shape: [8192, 8192]
+262: model.layers.34.self_attn.v_proj.weight shape: [1024, 8192]
+263: model.layers.35.input_layernorm.weight shape: [8192]
+264: model.layers.35.mlp.down_proj.weight shape: [8192, 28672]
+265: model.layers.35.mlp.gate_proj.weight shape: [28672, 8192]
+266: model.layers.35.mlp.up_proj.weight shape: [28672, 8192]
+267: model.layers.35.post_attention_layernorm.weight shape: [8192]
+268: model.layers.35.self_attn.k_proj.weight shape: [1024, 8192]
+269: model.layers.35.self_attn.o_proj.weight shape: [8192, 8192]
+270: model.layers.35.self_attn.q_proj.weight shape: [8192, 8192]
+271: model.layers.35.self_attn.v_proj.weight shape: [1024, 8192]
+272: model.layers.36.input_layernorm.weight shape: [8192]
+273: model.layers.36.mlp.down_proj.weight shape: [8192, 28672]
+274: model.layers.36.mlp.gate_proj.weight shape: [28672, 8192]
+275: model.layers.36.mlp.up_proj.weight shape: [28672, 8192]
+276: model.layers.36.post_attention_layernorm.weight shape: [8192]
+277: model.layers.36.self_attn.k_proj.weight shape: [1024, 8192]
+278: model.layers.36.self_attn.o_proj.weight shape: [8192, 8192]
+279: model.layers.36.self_attn.q_proj.weight shape: [8192, 8192]
+280: model.layers.36.self_attn.v_proj.weight shape: [1024, 8192]
+281: model.layers.37.input_layernorm.weight shape: [8192]
+282: model.layers.37.mlp.down_proj.weight shape: [8192, 28672]
+283: model.layers.37.mlp.gate_proj.weight shape: [28672, 8192]
+284: model.layers.37.mlp.up_proj.weight shape: [28672, 8192]
+285: model.layers.37.post_attention_layernorm.weight shape: [8192]
+286: model.layers.37.self_attn.k_proj.weight shape: [1024, 8192]
+287: model.layers.37.self_attn.o_proj.weight shape: [8192, 8192]
+288: model.layers.37.self_attn.q_proj.weight shape: [8192, 8192]
+289: model.layers.37.self_attn.v_proj.weight shape: [1024, 8192]
+290: model.layers.38.input_layernorm.weight shape: [8192]
+291: model.layers.38.mlp.down_proj.weight shape: [8192, 28672]
+292: model.layers.38.mlp.gate_proj.weight shape: [28672, 8192]
+293: model.layers.38.mlp.up_proj.weight shape: [28672, 8192]
+294: model.layers.38.post_attention_layernorm.weight shape: [8192]
+295: model.layers.38.self_attn.k_proj.weight shape: [1024, 8192]
+296: model.layers.38.self_attn.o_proj.weight shape: [8192, 8192]
+297: model.layers.38.self_attn.q_proj.weight shape: [8192, 8192]
+298: model.layers.38.self_attn.v_proj.weight shape: [1024, 8192]
+299: model.layers.39.input_layernorm.weight shape: [8192]
+300: model.layers.39.mlp.down_proj.weight shape: [8192, 28672]
+301: model.layers.39.mlp.gate_proj.weight shape: [28672, 8192]
+302: model.layers.39.mlp.up_proj.weight shape: [28672, 8192]
+303: model.layers.39.post_attention_layernorm.weight shape: [8192]
+304: model.layers.39.self_attn.k_proj.weight shape: [1024, 8192]
+305: model.layers.39.self_attn.o_proj.weight shape: [8192, 8192]
+306: model.layers.39.self_attn.q_proj.weight shape: [8192, 8192]
+307: model.layers.39.self_attn.v_proj.weight shape: [1024, 8192]
+308: model.layers.4.input_layernorm.weight shape: [8192]
+309: model.layers.4.mlp.down_proj.weight shape: [8192, 28672]
+310: model.layers.4.mlp.gate_proj.weight shape: [28672, 8192]
+311: model.layers.4.mlp.up_proj.weight shape: [28672, 8192]
+312: model.layers.4.post_attention_layernorm.weight shape: [8192]
+313: model.layers.4.self_attn.k_proj.weight shape: [1024, 8192]
+314: model.layers.4.self_attn.o_proj.weight shape: [8192, 8192]
+315: model.layers.4.self_attn.q_proj.weight shape: [8192, 8192]
+316: model.layers.4.self_attn.v_proj.weight shape: [1024, 8192]
+317: model.layers.40.input_layernorm.weight shape: [8192]
+318: model.layers.40.mlp.down_proj.weight shape: [8192, 28672]
+319: model.layers.40.mlp.gate_proj.weight shape: [28672, 8192]
+320: model.layers.40.mlp.up_proj.weight shape: [28672, 8192]
+321: model.layers.40.post_attention_layernorm.weight shape: [8192]
+322: model.layers.40.self_attn.k_proj.weight shape: [1024, 8192]
+323: model.layers.40.self_attn.o_proj.weight shape: [8192, 8192]
+324: model.layers.40.self_attn.q_proj.weight shape: [8192, 8192]
+325: model.layers.40.self_attn.v_proj.weight shape: [1024, 8192]
+326: model.layers.41.input_layernorm.weight shape: [8192]
+327: model.layers.41.mlp.down_proj.weight shape: [8192, 28672]
+328: model.layers.41.mlp.gate_proj.weight shape: [28672, 8192]
+329: model.layers.41.mlp.up_proj.weight shape: [28672, 8192]
+330: model.layers.41.post_attention_layernorm.weight shape: [8192]
+331: model.layers.41.self_attn.k_proj.weight shape: [1024, 8192]
+332: model.layers.41.self_attn.o_proj.weight shape: [8192, 8192]
+333: model.layers.41.self_attn.q_proj.weight shape: [8192, 8192]
+334: model.layers.41.self_attn.v_proj.weight shape: [1024, 8192]
+335: model.layers.42.input_layernorm.weight shape: [8192]
+336: model.layers.42.mlp.down_proj.weight shape: [8192, 28672]
+337: model.layers.42.mlp.gate_proj.weight shape: [28672, 8192]
+338: model.layers.42.mlp.up_proj.weight shape: [28672, 8192]
+339: model.layers.42.post_attention_layernorm.weight shape: [8192]
+340: model.layers.42.self_attn.k_proj.weight shape: [1024, 8192]
+341: model.layers.42.self_attn.o_proj.weight shape: [8192, 8192]
+342: model.layers.42.self_attn.q_proj.weight shape: [8192, 8192]
+343: model.layers.42.self_attn.v_proj.weight shape: [1024, 8192]
+344: model.layers.43.input_layernorm.weight shape: [8192]
+345: model.layers.43.mlp.down_proj.weight shape: [8192, 28672]
+346: model.layers.43.mlp.gate_proj.weight shape: [28672, 8192]
+347: model.layers.43.mlp.up_proj.weight shape: [28672, 8192]
+348: model.layers.43.post_attention_layernorm.weight shape: [8192]
+349: model.layers.43.self_attn.k_proj.weight shape: [1024, 8192]
+350: model.layers.43.self_attn.o_proj.weight shape: [8192, 8192]
+351: model.layers.43.self_attn.q_proj.weight shape: [8192, 8192]
+352: model.layers.43.self_attn.v_proj.weight shape: [1024, 8192]
+353: model.layers.44.input_layernorm.weight shape: [8192]
+354: model.layers.44.mlp.down_proj.weight shape: [8192, 28672]
+355: model.layers.44.mlp.gate_proj.weight shape: [28672, 8192]
+356: model.layers.44.mlp.up_proj.weight shape: [28672, 8192]
+357: model.layers.44.post_attention_layernorm.weight shape: [8192]
+358: model.layers.44.self_attn.k_proj.weight shape: [1024, 8192]
+359: model.layers.44.self_attn.o_proj.weight shape: [8192, 8192]
+360: model.layers.44.self_attn.q_proj.weight shape: [8192, 8192]
+361: model.layers.44.self_attn.v_proj.weight shape: [1024, 8192]
+362: model.layers.45.input_layernorm.weight shape: [8192]
+363: model.layers.45.mlp.down_proj.weight shape: [8192, 28672]
+364: model.layers.45.mlp.gate_proj.weight shape: [28672, 8192]
+365: model.layers.45.mlp.up_proj.weight shape: [28672, 8192]
+366: model.layers.45.post_attention_layernorm.weight shape: [8192]
+367: model.layers.45.self_attn.k_proj.weight shape: [1024, 8192]
+368: model.layers.45.self_attn.o_proj.weight shape: [8192, 8192]
+369: model.layers.45.self_attn.q_proj.weight shape: [8192, 8192]
+370: model.layers.45.self_attn.v_proj.weight shape: [1024, 8192]
+371: model.layers.46.input_layernorm.weight shape: [8192]
+372: model.layers.46.mlp.down_proj.weight shape: [8192, 28672]
+373: model.layers.46.mlp.gate_proj.weight shape: [28672, 8192]
+374: model.layers.46.mlp.up_proj.weight shape: [28672, 8192]
+375: model.layers.46.post_attention_layernorm.weight shape: [8192]
+376: model.layers.46.self_attn.k_proj.weight shape: [1024, 8192]
+377: model.layers.46.self_attn.o_proj.weight shape: [8192, 8192]
+378: model.layers.46.self_attn.q_proj.weight shape: [8192, 8192]
+379: model.layers.46.self_attn.v_proj.weight shape: [1024, 8192]
+380: model.layers.47.input_layernorm.weight shape: [8192]
+381: model.layers.47.mlp.down_proj.weight shape: [8192, 28672]
+382: model.layers.47.mlp.gate_proj.weight shape: [28672, 8192]
+383: model.layers.47.mlp.up_proj.weight shape: [28672, 8192]
+384: model.layers.47.post_attention_layernorm.weight shape: [8192]
+385: model.layers.47.self_attn.k_proj.weight shape: [1024, 8192]
+386: model.layers.47.self_attn.o_proj.weight shape: [8192, 8192]
+387: model.layers.47.self_attn.q_proj.weight shape: [8192, 8192]
+388: model.layers.47.self_attn.v_proj.weight shape: [1024, 8192]
+389: model.layers.48.input_layernorm.weight shape: [8192]
+390: model.layers.48.mlp.down_proj.weight shape: [8192, 28672]
+391: model.layers.48.mlp.gate_proj.weight shape: [28672, 8192]
+392: model.layers.48.mlp.up_proj.weight shape: [28672, 8192]
+393: model.layers.48.post_attention_layernorm.weight shape: [8192]
+394: model.layers.48.self_attn.k_proj.weight shape: [1024, 8192]
+395: model.layers.48.self_attn.o_proj.weight shape: [8192, 8192]
+396: model.layers.48.self_attn.q_proj.weight shape: [8192, 8192]
+397: model.layers.48.self_attn.v_proj.weight shape: [1024, 8192]
+398: model.layers.49.input_layernorm.weight shape: [8192]
+399: model.layers.49.mlp.down_proj.weight shape: [8192, 28672]
+400: model.layers.49.mlp.gate_proj.weight shape: [28672, 8192]
+401: model.layers.49.mlp.up_proj.weight shape: [28672, 8192]
+402: model.layers.49.post_attention_layernorm.weight shape: [8192]
+403: model.layers.49.self_attn.k_proj.weight shape: [1024, 8192]
+404: model.layers.49.self_attn.o_proj.weight shape: [8192, 8192]
+405: model.layers.49.self_attn.q_proj.weight shape: [8192, 8192]
+406: model.layers.49.self_attn.v_proj.weight shape: [1024, 8192]
+407: model.layers.5.input_layernorm.weight shape: [8192]
+408: model.layers.5.mlp.down_proj.weight shape: [8192, 28672]
+409: model.layers.5.mlp.gate_proj.weight shape: [28672, 8192]
+410: model.layers.5.mlp.up_proj.weight shape: [28672, 8192]
+411: model.layers.5.post_attention_layernorm.weight shape: [8192]
+412: model.layers.5.self_attn.k_proj.weight shape: [1024, 8192]
+413: model.layers.5.self_attn.o_proj.weight shape: [8192, 8192]
+414: model.layers.5.self_attn.q_proj.weight shape: [8192, 8192]
+415: model.layers.5.self_attn.v_proj.weight shape: [1024, 8192]
+416: model.layers.50.input_layernorm.weight shape: [8192]
+417: model.layers.50.mlp.down_proj.weight shape: [8192, 28672]
+418: model.layers.50.mlp.gate_proj.weight shape: [28672, 8192]
+419: model.layers.50.mlp.up_proj.weight shape: [28672, 8192]
+420: model.layers.50.post_attention_layernorm.weight shape: [8192]
+421: model.layers.50.self_attn.k_proj.weight shape: [1024, 8192]
+422: model.layers.50.self_attn.o_proj.weight shape: [8192, 8192]
+423: model.layers.50.self_attn.q_proj.weight shape: [8192, 8192]
+424: model.layers.50.self_attn.v_proj.weight shape: [1024, 8192]
+425: model.layers.51.input_layernorm.weight shape: [8192]
+426: model.layers.51.mlp.down_proj.weight shape: [8192, 28672]
+427: model.layers.51.mlp.gate_proj.weight shape: [28672, 8192]
+428: model.layers.51.mlp.up_proj.weight shape: [28672, 8192]
+429: model.layers.51.post_attention_layernorm.weight shape: [8192]
+430: model.layers.51.self_attn.k_proj.weight shape: [1024, 8192]
+431: model.layers.51.self_attn.o_proj.weight shape: [8192, 8192]
+432: model.layers.51.self_attn.q_proj.weight shape: [8192, 8192]
+433: model.layers.51.self_attn.v_proj.weight shape: [1024, 8192]
+434: model.layers.52.input_layernorm.weight shape: [8192]
+435: model.layers.52.mlp.down_proj.weight shape: [8192, 28672]
+436: model.layers.52.mlp.gate_proj.weight shape: [28672, 8192]
+437: model.layers.52.mlp.up_proj.weight shape: [28672, 8192]
+438: model.layers.52.post_attention_layernorm.weight shape: [8192]
+439: model.layers.52.self_attn.k_proj.weight shape: [1024, 8192]
+440: model.layers.52.self_attn.o_proj.weight shape: [8192, 8192]
+441: model.layers.52.self_attn.q_proj.weight shape: [8192, 8192]
+442: model.layers.52.self_attn.v_proj.weight shape: [1024, 8192]
+443: model.layers.53.input_layernorm.weight shape: [8192]
+444: model.layers.53.mlp.down_proj.weight shape: [8192, 28672]
+445: model.layers.53.mlp.gate_proj.weight shape: [28672, 8192]
+446: model.layers.53.mlp.up_proj.weight shape: [28672, 8192]
+447: model.layers.53.post_attention_layernorm.weight shape: [8192]
+448: model.layers.53.self_attn.k_proj.weight shape: [1024, 8192]
+449: model.layers.53.self_attn.o_proj.weight shape: [8192, 8192]
+450: model.layers.53.self_attn.q_proj.weight shape: [8192, 8192]
+451: model.layers.53.self_attn.v_proj.weight shape: [1024, 8192]
+452: model.layers.54.input_layernorm.weight shape: [8192]
+453: model.layers.54.mlp.down_proj.weight shape: [8192, 28672]
+454: model.layers.54.mlp.gate_proj.weight shape: [28672, 8192]
+455: model.layers.54.mlp.up_proj.weight shape: [28672, 8192]
+456: model.layers.54.post_attention_layernorm.weight shape: [8192]
+457: model.layers.54.self_attn.k_proj.weight shape: [1024, 8192]
+458: model.layers.54.self_attn.o_proj.weight shape: [8192, 8192]
+459: model.layers.54.self_attn.q_proj.weight shape: [8192, 8192]
+460: model.layers.54.self_attn.v_proj.weight shape: [1024, 8192]
+461: model.layers.55.input_layernorm.weight shape: [8192]
+462: model.layers.55.mlp.down_proj.weight shape: [8192, 28672]
+463: model.layers.55.mlp.gate_proj.weight shape: [28672, 8192]
+464: model.layers.55.mlp.up_proj.weight shape: [28672, 8192]
+465: model.layers.55.post_attention_layernorm.weight shape: [8192]
+466: model.layers.55.self_attn.k_proj.weight shape: [1024, 8192]
+467: model.layers.55.self_attn.o_proj.weight shape: [8192, 8192]
+468: model.layers.55.self_attn.q_proj.weight shape: [8192, 8192]
+469: model.layers.55.self_attn.v_proj.weight shape: [1024, 8192]
+470: model.layers.56.input_layernorm.weight shape: [8192]
+471: model.layers.56.mlp.down_proj.weight shape: [8192, 28672]
+472: model.layers.56.mlp.gate_proj.weight shape: [28672, 8192]
+473: model.layers.56.mlp.up_proj.weight shape: [28672, 8192]
+474: model.layers.56.post_attention_layernorm.weight shape: [8192]
+475: model.layers.56.self_attn.k_proj.weight shape: [1024, 8192]
+476: model.layers.56.self_attn.o_proj.weight shape: [8192, 8192]
+477: model.layers.56.self_attn.q_proj.weight shape: [8192, 8192]
+478: model.layers.56.self_attn.v_proj.weight shape: [1024, 8192]
+479: model.layers.57.input_layernorm.weight shape: [8192]
+480: model.layers.57.mlp.down_proj.weight shape: [8192, 28672]
+481: model.layers.57.mlp.gate_proj.weight shape: [28672, 8192]
+482: model.layers.57.mlp.up_proj.weight shape: [28672, 8192]
+483: model.layers.57.post_attention_layernorm.weight shape: [8192]
+484: model.layers.57.self_attn.k_proj.weight shape: [1024, 8192]
+485: model.layers.57.self_attn.o_proj.weight shape: [8192, 8192]
+486: model.layers.57.self_attn.q_proj.weight shape: [8192, 8192]
+487: model.layers.57.self_attn.v_proj.weight shape: [1024, 8192]
+488: model.layers.58.input_layernorm.weight shape: [8192]
+489: model.layers.58.mlp.down_proj.weight shape: [8192, 28672]
+490: model.layers.58.mlp.gate_proj.weight shape: [28672, 8192]
+491: model.layers.58.mlp.up_proj.weight shape: [28672, 8192]
+492: model.layers.58.post_attention_layernorm.weight shape: [8192]
+493: model.layers.58.self_attn.k_proj.weight shape: [1024, 8192]
+494: model.layers.58.self_attn.o_proj.weight shape: [8192, 8192]
+495: model.layers.58.self_attn.q_proj.weight shape: [8192, 8192]
+496: model.layers.58.self_attn.v_proj.weight shape: [1024, 8192]
+497: model.layers.59.input_layernorm.weight shape: [8192]
+498: model.layers.59.mlp.down_proj.weight shape: [8192, 28672]
+499: model.layers.59.mlp.gate_proj.weight shape: [28672, 8192]
+500: model.layers.59.mlp.up_proj.weight shape: [28672, 8192]
+501: model.layers.59.post_attention_layernorm.weight shape: [8192]
+502: model.layers.59.self_attn.k_proj.weight shape: [1024, 8192]
+503: model.layers.59.self_attn.o_proj.weight shape: [8192, 8192]
+504: model.layers.59.self_attn.q_proj.weight shape: [8192, 8192]
+505: model.layers.59.self_attn.v_proj.weight shape: [1024, 8192]
+506: model.layers.6.input_layernorm.weight shape: [8192]
+507: model.layers.6.mlp.down_proj.weight shape: [8192, 28672]
+508: model.layers.6.mlp.gate_proj.weight shape: [28672, 8192]
+509: model.layers.6.mlp.up_proj.weight shape: [28672, 8192]
+510: model.layers.6.post_attention_layernorm.weight shape: [8192]
+511: model.layers.6.self_attn.k_proj.weight shape: [1024, 8192]
+512: model.layers.6.self_attn.o_proj.weight shape: [8192, 8192]
+513: model.layers.6.self_attn.q_proj.weight shape: [8192, 8192]
+514: model.layers.6.self_attn.v_proj.weight shape: [1024, 8192]
+515: model.layers.60.input_layernorm.weight shape: [8192]
+516: model.layers.60.mlp.down_proj.weight shape: [8192, 28672]
+517: model.layers.60.mlp.gate_proj.weight shape: [28672, 8192]
+518: model.layers.60.mlp.up_proj.weight shape: [28672, 8192]
+519: model.layers.60.post_attention_layernorm.weight shape: [8192]
+520: model.layers.60.self_attn.k_proj.weight shape: [1024, 8192]
+521: model.layers.60.self_attn.o_proj.weight shape: [8192, 8192]
+522: model.layers.60.self_attn.q_proj.weight shape: [8192, 8192]
+523: model.layers.60.self_attn.v_proj.weight shape: [1024, 8192]
+524: model.layers.61.input_layernorm.weight shape: [8192]
+525: model.layers.61.mlp.down_proj.weight shape: [8192, 28672]
+526: model.layers.61.mlp.gate_proj.weight shape: [28672, 8192]
+527: model.layers.61.mlp.up_proj.weight shape: [28672, 8192]
+528: model.layers.61.post_attention_layernorm.weight shape: [8192]
+529: model.layers.61.self_attn.k_proj.weight shape: [1024, 8192]
+530: model.layers.61.self_attn.o_proj.weight shape: [8192, 8192]
+531: model.layers.61.self_attn.q_proj.weight shape: [8192, 8192]
+532: model.layers.61.self_attn.v_proj.weight shape: [1024, 8192]
+533: model.layers.62.input_layernorm.weight shape: [8192]
+534: model.layers.62.mlp.down_proj.weight shape: [8192, 28672]
+535: model.layers.62.mlp.gate_proj.weight shape: [28672, 8192]
+536: model.layers.62.mlp.up_proj.weight shape: [28672, 8192]
+537: model.layers.62.post_attention_layernorm.weight shape: [8192]
+538: model.layers.62.self_attn.k_proj.weight shape: [1024, 8192]
+539: model.layers.62.self_attn.o_proj.weight shape: [8192, 8192]
+540: model.layers.62.self_attn.q_proj.weight shape: [8192, 8192]
+541: model.layers.62.self_attn.v_proj.weight shape: [1024, 8192]
+542: model.layers.63.input_layernorm.weight shape: [8192]
+543: model.layers.63.mlp.down_proj.weight shape: [8192, 28672]
+544: model.layers.63.mlp.gate_proj.weight shape: [28672, 8192]
+545: model.layers.63.mlp.up_proj.weight shape: [28672, 8192]
+546: model.layers.63.post_attention_layernorm.weight shape: [8192]
+547: model.layers.63.self_attn.k_proj.weight shape: [1024, 8192]
+548: model.layers.63.self_attn.o_proj.weight shape: [8192, 8192]
+549: model.layers.63.self_attn.q_proj.weight shape: [8192, 8192]
+550: model.layers.63.self_attn.v_proj.weight shape: [1024, 8192]
+551: model.layers.64.input_layernorm.weight shape: [8192]
+552: model.layers.64.mlp.down_proj.weight shape: [8192, 28672]
+553: model.layers.64.mlp.gate_proj.weight shape: [28672, 8192]
+554: model.layers.64.mlp.up_proj.weight shape: [28672, 8192]
+555: model.layers.64.post_attention_layernorm.weight shape: [8192]
+556: model.layers.64.self_attn.k_proj.weight shape: [1024, 8192]
+557: model.layers.64.self_attn.o_proj.weight shape: [8192, 8192]
+558: model.layers.64.self_attn.q_proj.weight shape: [8192, 8192]
+559: model.layers.64.self_attn.v_proj.weight shape: [1024, 8192]
+560: model.layers.65.input_layernorm.weight shape: [8192]
+561: model.layers.65.mlp.down_proj.weight shape: [8192, 28672]
+562: model.layers.65.mlp.gate_proj.weight shape: [28672, 8192]
+563: model.layers.65.mlp.up_proj.weight shape: [28672, 8192]
+564: model.layers.65.post_attention_layernorm.weight shape: [8192]
+565: model.layers.65.self_attn.k_proj.weight shape: [1024, 8192]
+566: model.layers.65.self_attn.o_proj.weight shape: [8192, 8192]
+567: model.layers.65.self_attn.q_proj.weight shape: [8192, 8192]
+568: model.layers.65.self_attn.v_proj.weight shape: [1024, 8192]
+569: model.layers.66.input_layernorm.weight shape: [8192]
+570: model.layers.66.mlp.down_proj.weight shape: [8192, 28672]
+571: model.layers.66.mlp.gate_proj.weight shape: [28672, 8192]
+572: model.layers.66.mlp.up_proj.weight shape: [28672, 8192]
+573: model.layers.66.post_attention_layernorm.weight shape: [8192]
+574: model.layers.66.self_attn.k_proj.weight shape: [1024, 8192]
+575: model.layers.66.self_attn.o_proj.weight shape: [8192, 8192]
+576: model.layers.66.self_attn.q_proj.weight shape: [8192, 8192]
+577: model.layers.66.self_attn.v_proj.weight shape: [1024, 8192]
+578: model.layers.67.input_layernorm.weight shape: [8192]
+579: model.layers.67.mlp.down_proj.weight shape: [8192, 28672]
+580: model.layers.67.mlp.gate_proj.weight shape: [28672, 8192]
+581: model.layers.67.mlp.up_proj.weight shape: [28672, 8192]
+582: model.layers.67.post_attention_layernorm.weight shape: [8192]
+583: model.layers.67.self_attn.k_proj.weight shape: [1024, 8192]
+584: model.layers.67.self_attn.o_proj.weight shape: [8192, 8192]
+585: model.layers.67.self_attn.q_proj.weight shape: [8192, 8192]
+586: model.layers.67.self_attn.v_proj.weight shape: [1024, 8192]
+587: model.layers.68.input_layernorm.weight shape: [8192]
+588: model.layers.68.mlp.down_proj.weight shape: [8192, 28672]
+589: model.layers.68.mlp.gate_proj.weight shape: [28672, 8192]
+590: model.layers.68.mlp.up_proj.weight shape: [28672, 8192]
+591: model.layers.68.post_attention_layernorm.weight shape: [8192]
+592: model.layers.68.self_attn.k_proj.weight shape: [1024, 8192]
+593: model.layers.68.self_attn.o_proj.weight shape: [8192, 8192]
+594: model.layers.68.self_attn.q_proj.weight shape: [8192, 8192]
+595: model.layers.68.self_attn.v_proj.weight shape: [1024, 8192]
+596: model.layers.69.input_layernorm.weight shape: [8192]
+597: model.layers.69.mlp.down_proj.weight shape: [8192, 28672]
+598: model.layers.69.mlp.gate_proj.weight shape: [28672, 8192]
+599: model.layers.69.mlp.up_proj.weight shape: [28672, 8192]
+600: model.layers.69.post_attention_layernorm.weight shape: [8192]
+601: model.layers.69.self_attn.k_proj.weight shape: [1024, 8192]
+602: model.layers.69.self_attn.o_proj.weight shape: [8192, 8192]
+603: model.layers.69.self_attn.q_proj.weight shape: [8192, 8192]
+604: model.layers.69.self_attn.v_proj.weight shape: [1024, 8192]
+605: model.layers.7.input_layernorm.weight shape: [8192]
+606: model.layers.7.mlp.down_proj.weight shape: [8192, 28672]
+607: model.layers.7.mlp.gate_proj.weight shape: [28672, 8192]
+608: model.layers.7.mlp.up_proj.weight shape: [28672, 8192]
+609: model.layers.7.post_attention_layernorm.weight shape: [8192]
+610: model.layers.7.self_attn.k_proj.weight shape: [1024, 8192]
+611: model.layers.7.self_attn.o_proj.weight shape: [8192, 8192]
+612: model.layers.7.self_attn.q_proj.weight shape: [8192, 8192]
+613: model.layers.7.self_attn.v_proj.weight shape: [1024, 8192]
+614: model.layers.70.input_layernorm.weight shape: [8192]
+615: model.layers.70.mlp.down_proj.weight shape: [8192, 28672]
+616: model.layers.70.mlp.gate_proj.weight shape: [28672, 8192]
+617: model.layers.70.mlp.up_proj.weight shape: [28672, 8192]
+618: model.layers.70.post_attention_layernorm.weight shape: [8192]
+619: model.layers.70.self_attn.k_proj.weight shape: [1024, 8192]
+620: model.layers.70.self_attn.o_proj.weight shape: [8192, 8192]
+621: model.layers.70.self_attn.q_proj.weight shape: [8192, 8192]
+622: model.layers.70.self_attn.v_proj.weight shape: [1024, 8192]
+623: model.layers.71.input_layernorm.weight shape: [8192]
+624: model.layers.71.mlp.down_proj.weight shape: [8192, 28672]
+625: model.layers.71.mlp.gate_proj.weight shape: [28672, 8192]
+626: model.layers.71.mlp.up_proj.weight shape: [28672, 8192]
+627: model.layers.71.post_attention_layernorm.weight shape: [8192]
+628: model.layers.71.self_attn.k_proj.weight shape: [1024, 8192]
+629: model.layers.71.self_attn.o_proj.weight shape: [8192, 8192]
+630: model.layers.71.self_attn.q_proj.weight shape: [8192, 8192]
+631: model.layers.71.self_attn.v_proj.weight shape: [1024, 8192]
+632: model.layers.72.input_layernorm.weight shape: [8192]
+633: model.layers.72.mlp.down_proj.weight shape: [8192, 28672]
+634: model.layers.72.mlp.gate_proj.weight shape: [28672, 8192]
+635: model.layers.72.mlp.up_proj.weight shape: [28672, 8192]
+636: model.layers.72.post_attention_layernorm.weight shape: [8192]
+637: model.layers.72.self_attn.k_proj.weight shape: [1024, 8192]
+638: model.layers.72.self_attn.o_proj.weight shape: [8192, 8192]
+639: model.layers.72.self_attn.q_proj.weight shape: [8192, 8192]
+640: model.layers.72.self_attn.v_proj.weight shape: [1024, 8192]
+641: model.layers.73.input_layernorm.weight shape: [8192]
+642: model.layers.73.mlp.down_proj.weight shape: [8192, 28672]
+643: model.layers.73.mlp.gate_proj.weight shape: [28672, 8192]
+644: model.layers.73.mlp.up_proj.weight shape: [28672, 8192]
+645: model.layers.73.post_attention_layernorm.weight shape: [8192]
+646: model.layers.73.self_attn.k_proj.weight shape: [1024, 8192]
+647: model.layers.73.self_attn.o_proj.weight shape: [8192, 8192]
+648: model.layers.73.self_attn.q_proj.weight shape: [8192, 8192]
+649: model.layers.73.self_attn.v_proj.weight shape: [1024, 8192]
+650: model.layers.74.input_layernorm.weight shape: [8192]
+651: model.layers.74.mlp.down_proj.weight shape: [8192, 28672]
+652: model.layers.74.mlp.gate_proj.weight shape: [28672, 8192]
+653: model.layers.74.mlp.up_proj.weight shape: [28672, 8192]
+654: model.layers.74.post_attention_layernorm.weight shape: [8192]
+655: model.layers.74.self_attn.k_proj.weight shape: [1024, 8192]
+656: model.layers.74.self_attn.o_proj.weight shape: [8192, 8192]
+657: model.layers.74.self_attn.q_proj.weight shape: [8192, 8192]
+658: model.layers.74.self_attn.v_proj.weight shape: [1024, 8192]
+659: model.layers.75.input_layernorm.weight shape: [8192]
+660: model.layers.75.mlp.down_proj.weight shape: [8192, 28672]
+661: model.layers.75.mlp.gate_proj.weight shape: [28672, 8192]
+662: model.layers.75.mlp.up_proj.weight shape: [28672, 8192]
+663: model.layers.75.post_attention_layernorm.weight shape: [8192]
+664: model.layers.75.self_attn.k_proj.weight shape: [1024, 8192]
+665: model.layers.75.self_attn.o_proj.weight shape: [8192, 8192]
+666: model.layers.75.self_attn.q_proj.weight shape: [8192, 8192]
+667: model.layers.75.self_attn.v_proj.weight shape: [1024, 8192]
+668: model.layers.76.input_layernorm.weight shape: [8192]
+669: model.layers.76.mlp.down_proj.weight shape: [8192, 28672]
+670: model.layers.76.mlp.gate_proj.weight shape: [28672, 8192]
+671: model.layers.76.mlp.up_proj.weight shape: [28672, 8192]
+672: model.layers.76.post_attention_layernorm.weight shape: [8192]
+673: model.layers.76.self_attn.k_proj.weight shape: [1024, 8192]
+674: model.layers.76.self_attn.o_proj.weight shape: [8192, 8192]
+675: model.layers.76.self_attn.q_proj.weight shape: [8192, 8192]
+676: model.layers.76.self_attn.v_proj.weight shape: [1024, 8192]
+677: model.layers.77.input_layernorm.weight shape: [8192]
+678: model.layers.77.mlp.down_proj.weight shape: [8192, 28672]
+679: model.layers.77.mlp.gate_proj.weight shape: [28672, 8192]
+680: model.layers.77.mlp.up_proj.weight shape: [28672, 8192]
+681: model.layers.77.post_attention_layernorm.weight shape: [8192]
+682: model.layers.77.self_attn.k_proj.weight shape: [1024, 8192]
+683: model.layers.77.self_attn.o_proj.weight shape: [8192, 8192]
+684: model.layers.77.self_attn.q_proj.weight shape: [8192, 8192]
+685: model.layers.77.self_attn.v_proj.weight shape: [1024, 8192]
+686: model.layers.78.input_layernorm.weight shape: [8192]
+687: model.layers.78.mlp.down_proj.weight shape: [8192, 28672]
+688: model.layers.78.mlp.gate_proj.weight shape: [28672, 8192]
+689: model.layers.78.mlp.up_proj.weight shape: [28672, 8192]
+690: model.layers.78.post_attention_layernorm.weight shape: [8192]
+691: model.layers.78.self_attn.k_proj.weight shape: [1024, 8192]
+692: model.layers.78.self_attn.o_proj.weight shape: [8192, 8192]
+693: model.layers.78.self_attn.q_proj.weight shape: [8192, 8192]
+694: model.layers.78.self_attn.v_proj.weight shape: [1024, 8192]
+695: model.layers.79.input_layernorm.weight shape: [8192]
+696: model.layers.79.mlp.down_proj.weight shape: [8192, 28672]
+697: model.layers.79.mlp.gate_proj.weight shape: [28672, 8192]
+698: model.layers.79.mlp.up_proj.weight shape: [28672, 8192]
+699: model.layers.79.post_attention_layernorm.weight shape: [8192]
+700: model.layers.79.self_attn.k_proj.weight shape: [1024, 8192]
+701: model.layers.79.self_attn.o_proj.weight shape: [8192, 8192]
+702: model.layers.79.self_attn.q_proj.weight shape: [8192, 8192]
+703: model.layers.79.self_attn.v_proj.weight shape: [1024, 8192]
+704: model.layers.8.input_layernorm.weight shape: [8192]
+705: model.layers.8.mlp.down_proj.weight shape: [8192, 28672]
+706: model.layers.8.mlp.gate_proj.weight shape: [28672, 8192]
+707: model.layers.8.mlp.up_proj.weight shape: [28672, 8192]
+708: model.layers.8.post_attention_layernorm.weight shape: [8192]
+709: model.layers.8.self_attn.k_proj.weight shape: [1024, 8192]
+710: model.layers.8.self_attn.o_proj.weight shape: [8192, 8192]
+711: model.layers.8.self_attn.q_proj.weight shape: [8192, 8192]
+712: model.layers.8.self_attn.v_proj.weight shape: [1024, 8192]
+713: model.layers.9.input_layernorm.weight shape: [8192]
+714: model.layers.9.mlp.down_proj.weight shape: [8192, 28672]
+715: model.layers.9.mlp.gate_proj.weight shape: [28672, 8192]
+716: model.layers.9.mlp.up_proj.weight shape: [28672, 8192]
+717: model.layers.9.post_attention_layernorm.weight shape: [8192]
+718: model.layers.9.self_attn.k_proj.weight shape: [1024, 8192]
+719: model.layers.9.self_attn.o_proj.weight shape: [8192, 8192]
+720: model.layers.9.self_attn.q_proj.weight shape: [8192, 8192]
+721: model.layers.9.self_attn.v_proj.weight shape: [1024, 8192]
+722: model.norm.weight shape: [8192]
diff --git a/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.Llama_3_1_8b_ShapeTest.approved.txt b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.Llama_3_1_8b_ShapeTest.approved.txt
new file mode 100644
index 0000000000..887b49cfa6
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.Llama_3_1_8b_ShapeTest.approved.txt
@@ -0,0 +1,291 @@
+0: lm_head.weight shape: [128256, 4096]
+1: model.embed_tokens.weight shape: [128256, 4096]
+2: model.layers.0.input_layernorm.weight shape: [4096]
+3: model.layers.0.mlp.down_proj.weight shape: [4096, 14336]
+4: model.layers.0.mlp.gate_proj.weight shape: [14336, 4096]
+5: model.layers.0.mlp.up_proj.weight shape: [14336, 4096]
+6: model.layers.0.post_attention_layernorm.weight shape: [4096]
+7: model.layers.0.self_attn.k_proj.weight shape: [1024, 4096]
+8: model.layers.0.self_attn.o_proj.weight shape: [4096, 4096]
+9: model.layers.0.self_attn.q_proj.weight shape: [4096, 4096]
+10: model.layers.0.self_attn.v_proj.weight shape: [1024, 4096]
+11: model.layers.1.input_layernorm.weight shape: [4096]
+12: model.layers.1.mlp.down_proj.weight shape: [4096, 14336]
+13: model.layers.1.mlp.gate_proj.weight shape: [14336, 4096]
+14: model.layers.1.mlp.up_proj.weight shape: [14336, 4096]
+15: model.layers.1.post_attention_layernorm.weight shape: [4096]
+16: model.layers.1.self_attn.k_proj.weight shape: [1024, 4096]
+17: model.layers.1.self_attn.o_proj.weight shape: [4096, 4096]
+18: model.layers.1.self_attn.q_proj.weight shape: [4096, 4096]
+19: model.layers.1.self_attn.v_proj.weight shape: [1024, 4096]
+20: model.layers.10.input_layernorm.weight shape: [4096]
+21: model.layers.10.mlp.down_proj.weight shape: [4096, 14336]
+22: model.layers.10.mlp.gate_proj.weight shape: [14336, 4096]
+23: model.layers.10.mlp.up_proj.weight shape: [14336, 4096]
+24: model.layers.10.post_attention_layernorm.weight shape: [4096]
+25: model.layers.10.self_attn.k_proj.weight shape: [1024, 4096]
+26: model.layers.10.self_attn.o_proj.weight shape: [4096, 4096]
+27: model.layers.10.self_attn.q_proj.weight shape: [4096, 4096]
+28: model.layers.10.self_attn.v_proj.weight shape: [1024, 4096]
+29: model.layers.11.input_layernorm.weight shape: [4096]
+30: model.layers.11.mlp.down_proj.weight shape: [4096, 14336]
+31: model.layers.11.mlp.gate_proj.weight shape: [14336, 4096]
+32: model.layers.11.mlp.up_proj.weight shape: [14336, 4096]
+33: model.layers.11.post_attention_layernorm.weight shape: [4096]
+34: model.layers.11.self_attn.k_proj.weight shape: [1024, 4096]
+35: model.layers.11.self_attn.o_proj.weight shape: [4096, 4096]
+36: model.layers.11.self_attn.q_proj.weight shape: [4096, 4096]
+37: model.layers.11.self_attn.v_proj.weight shape: [1024, 4096]
+38: model.layers.12.input_layernorm.weight shape: [4096]
+39: model.layers.12.mlp.down_proj.weight shape: [4096, 14336]
+40: model.layers.12.mlp.gate_proj.weight shape: [14336, 4096]
+41: model.layers.12.mlp.up_proj.weight shape: [14336, 4096]
+42: model.layers.12.post_attention_layernorm.weight shape: [4096]
+43: model.layers.12.self_attn.k_proj.weight shape: [1024, 4096]
+44: model.layers.12.self_attn.o_proj.weight shape: [4096, 4096]
+45: model.layers.12.self_attn.q_proj.weight shape: [4096, 4096]
+46: model.layers.12.self_attn.v_proj.weight shape: [1024, 4096]
+47: model.layers.13.input_layernorm.weight shape: [4096]
+48: model.layers.13.mlp.down_proj.weight shape: [4096, 14336]
+49: model.layers.13.mlp.gate_proj.weight shape: [14336, 4096]
+50: model.layers.13.mlp.up_proj.weight shape: [14336, 4096]
+51: model.layers.13.post_attention_layernorm.weight shape: [4096]
+52: model.layers.13.self_attn.k_proj.weight shape: [1024, 4096]
+53: model.layers.13.self_attn.o_proj.weight shape: [4096, 4096]
+54: model.layers.13.self_attn.q_proj.weight shape: [4096, 4096]
+55: model.layers.13.self_attn.v_proj.weight shape: [1024, 4096]
+56: model.layers.14.input_layernorm.weight shape: [4096]
+57: model.layers.14.mlp.down_proj.weight shape: [4096, 14336]
+58: model.layers.14.mlp.gate_proj.weight shape: [14336, 4096]
+59: model.layers.14.mlp.up_proj.weight shape: [14336, 4096]
+60: model.layers.14.post_attention_layernorm.weight shape: [4096]
+61: model.layers.14.self_attn.k_proj.weight shape: [1024, 4096]
+62: model.layers.14.self_attn.o_proj.weight shape: [4096, 4096]
+63: model.layers.14.self_attn.q_proj.weight shape: [4096, 4096]
+64: model.layers.14.self_attn.v_proj.weight shape: [1024, 4096]
+65: model.layers.15.input_layernorm.weight shape: [4096]
+66: model.layers.15.mlp.down_proj.weight shape: [4096, 14336]
+67: model.layers.15.mlp.gate_proj.weight shape: [14336, 4096]
+68: model.layers.15.mlp.up_proj.weight shape: [14336, 4096]
+69: model.layers.15.post_attention_layernorm.weight shape: [4096]
+70: model.layers.15.self_attn.k_proj.weight shape: [1024, 4096]
+71: model.layers.15.self_attn.o_proj.weight shape: [4096, 4096]
+72: model.layers.15.self_attn.q_proj.weight shape: [4096, 4096]
+73: model.layers.15.self_attn.v_proj.weight shape: [1024, 4096]
+74: model.layers.16.input_layernorm.weight shape: [4096]
+75: model.layers.16.mlp.down_proj.weight shape: [4096, 14336]
+76: model.layers.16.mlp.gate_proj.weight shape: [14336, 4096]
+77: model.layers.16.mlp.up_proj.weight shape: [14336, 4096]
+78: model.layers.16.post_attention_layernorm.weight shape: [4096]
+79: model.layers.16.self_attn.k_proj.weight shape: [1024, 4096]
+80: model.layers.16.self_attn.o_proj.weight shape: [4096, 4096]
+81: model.layers.16.self_attn.q_proj.weight shape: [4096, 4096]
+82: model.layers.16.self_attn.v_proj.weight shape: [1024, 4096]
+83: model.layers.17.input_layernorm.weight shape: [4096]
+84: model.layers.17.mlp.down_proj.weight shape: [4096, 14336]
+85: model.layers.17.mlp.gate_proj.weight shape: [14336, 4096]
+86: model.layers.17.mlp.up_proj.weight shape: [14336, 4096]
+87: model.layers.17.post_attention_layernorm.weight shape: [4096]
+88: model.layers.17.self_attn.k_proj.weight shape: [1024, 4096]
+89: model.layers.17.self_attn.o_proj.weight shape: [4096, 4096]
+90: model.layers.17.self_attn.q_proj.weight shape: [4096, 4096]
+91: model.layers.17.self_attn.v_proj.weight shape: [1024, 4096]
+92: model.layers.18.input_layernorm.weight shape: [4096]
+93: model.layers.18.mlp.down_proj.weight shape: [4096, 14336]
+94: model.layers.18.mlp.gate_proj.weight shape: [14336, 4096]
+95: model.layers.18.mlp.up_proj.weight shape: [14336, 4096]
+96: model.layers.18.post_attention_layernorm.weight shape: [4096]
+97: model.layers.18.self_attn.k_proj.weight shape: [1024, 4096]
+98: model.layers.18.self_attn.o_proj.weight shape: [4096, 4096]
+99: model.layers.18.self_attn.q_proj.weight shape: [4096, 4096]
+100: model.layers.18.self_attn.v_proj.weight shape: [1024, 4096]
+101: model.layers.19.input_layernorm.weight shape: [4096]
+102: model.layers.19.mlp.down_proj.weight shape: [4096, 14336]
+103: model.layers.19.mlp.gate_proj.weight shape: [14336, 4096]
+104: model.layers.19.mlp.up_proj.weight shape: [14336, 4096]
+105: model.layers.19.post_attention_layernorm.weight shape: [4096]
+106: model.layers.19.self_attn.k_proj.weight shape: [1024, 4096]
+107: model.layers.19.self_attn.o_proj.weight shape: [4096, 4096]
+108: model.layers.19.self_attn.q_proj.weight shape: [4096, 4096]
+109: model.layers.19.self_attn.v_proj.weight shape: [1024, 4096]
+110: model.layers.2.input_layernorm.weight shape: [4096]
+111: model.layers.2.mlp.down_proj.weight shape: [4096, 14336]
+112: model.layers.2.mlp.gate_proj.weight shape: [14336, 4096]
+113: model.layers.2.mlp.up_proj.weight shape: [14336, 4096]
+114: model.layers.2.post_attention_layernorm.weight shape: [4096]
+115: model.layers.2.self_attn.k_proj.weight shape: [1024, 4096]
+116: model.layers.2.self_attn.o_proj.weight shape: [4096, 4096]
+117: model.layers.2.self_attn.q_proj.weight shape: [4096, 4096]
+118: model.layers.2.self_attn.v_proj.weight shape: [1024, 4096]
+119: model.layers.20.input_layernorm.weight shape: [4096]
+120: model.layers.20.mlp.down_proj.weight shape: [4096, 14336]
+121: model.layers.20.mlp.gate_proj.weight shape: [14336, 4096]
+122: model.layers.20.mlp.up_proj.weight shape: [14336, 4096]
+123: model.layers.20.post_attention_layernorm.weight shape: [4096]
+124: model.layers.20.self_attn.k_proj.weight shape: [1024, 4096]
+125: model.layers.20.self_attn.o_proj.weight shape: [4096, 4096]
+126: model.layers.20.self_attn.q_proj.weight shape: [4096, 4096]
+127: model.layers.20.self_attn.v_proj.weight shape: [1024, 4096]
+128: model.layers.21.input_layernorm.weight shape: [4096]
+129: model.layers.21.mlp.down_proj.weight shape: [4096, 14336]
+130: model.layers.21.mlp.gate_proj.weight shape: [14336, 4096]
+131: model.layers.21.mlp.up_proj.weight shape: [14336, 4096]
+132: model.layers.21.post_attention_layernorm.weight shape: [4096]
+133: model.layers.21.self_attn.k_proj.weight shape: [1024, 4096]
+134: model.layers.21.self_attn.o_proj.weight shape: [4096, 4096]
+135: model.layers.21.self_attn.q_proj.weight shape: [4096, 4096]
+136: model.layers.21.self_attn.v_proj.weight shape: [1024, 4096]
+137: model.layers.22.input_layernorm.weight shape: [4096]
+138: model.layers.22.mlp.down_proj.weight shape: [4096, 14336]
+139: model.layers.22.mlp.gate_proj.weight shape: [14336, 4096]
+140: model.layers.22.mlp.up_proj.weight shape: [14336, 4096]
+141: model.layers.22.post_attention_layernorm.weight shape: [4096]
+142: model.layers.22.self_attn.k_proj.weight shape: [1024, 4096]
+143: model.layers.22.self_attn.o_proj.weight shape: [4096, 4096]
+144: model.layers.22.self_attn.q_proj.weight shape: [4096, 4096]
+145: model.layers.22.self_attn.v_proj.weight shape: [1024, 4096]
+146: model.layers.23.input_layernorm.weight shape: [4096]
+147: model.layers.23.mlp.down_proj.weight shape: [4096, 14336]
+148: model.layers.23.mlp.gate_proj.weight shape: [14336, 4096]
+149: model.layers.23.mlp.up_proj.weight shape: [14336, 4096]
+150: model.layers.23.post_attention_layernorm.weight shape: [4096]
+151: model.layers.23.self_attn.k_proj.weight shape: [1024, 4096]
+152: model.layers.23.self_attn.o_proj.weight shape: [4096, 4096]
+153: model.layers.23.self_attn.q_proj.weight shape: [4096, 4096]
+154: model.layers.23.self_attn.v_proj.weight shape: [1024, 4096]
+155: model.layers.24.input_layernorm.weight shape: [4096]
+156: model.layers.24.mlp.down_proj.weight shape: [4096, 14336]
+157: model.layers.24.mlp.gate_proj.weight shape: [14336, 4096]
+158: model.layers.24.mlp.up_proj.weight shape: [14336, 4096]
+159: model.layers.24.post_attention_layernorm.weight shape: [4096]
+160: model.layers.24.self_attn.k_proj.weight shape: [1024, 4096]
+161: model.layers.24.self_attn.o_proj.weight shape: [4096, 4096]
+162: model.layers.24.self_attn.q_proj.weight shape: [4096, 4096]
+163: model.layers.24.self_attn.v_proj.weight shape: [1024, 4096]
+164: model.layers.25.input_layernorm.weight shape: [4096]
+165: model.layers.25.mlp.down_proj.weight shape: [4096, 14336]
+166: model.layers.25.mlp.gate_proj.weight shape: [14336, 4096]
+167: model.layers.25.mlp.up_proj.weight shape: [14336, 4096]
+168: model.layers.25.post_attention_layernorm.weight shape: [4096]
+169: model.layers.25.self_attn.k_proj.weight shape: [1024, 4096]
+170: model.layers.25.self_attn.o_proj.weight shape: [4096, 4096]
+171: model.layers.25.self_attn.q_proj.weight shape: [4096, 4096]
+172: model.layers.25.self_attn.v_proj.weight shape: [1024, 4096]
+173: model.layers.26.input_layernorm.weight shape: [4096]
+174: model.layers.26.mlp.down_proj.weight shape: [4096, 14336]
+175: model.layers.26.mlp.gate_proj.weight shape: [14336, 4096]
+176: model.layers.26.mlp.up_proj.weight shape: [14336, 4096]
+177: model.layers.26.post_attention_layernorm.weight shape: [4096]
+178: model.layers.26.self_attn.k_proj.weight shape: [1024, 4096]
+179: model.layers.26.self_attn.o_proj.weight shape: [4096, 4096]
+180: model.layers.26.self_attn.q_proj.weight shape: [4096, 4096]
+181: model.layers.26.self_attn.v_proj.weight shape: [1024, 4096]
+182: model.layers.27.input_layernorm.weight shape: [4096]
+183: model.layers.27.mlp.down_proj.weight shape: [4096, 14336]
+184: model.layers.27.mlp.gate_proj.weight shape: [14336, 4096]
+185: model.layers.27.mlp.up_proj.weight shape: [14336, 4096]
+186: model.layers.27.post_attention_layernorm.weight shape: [4096]
+187: model.layers.27.self_attn.k_proj.weight shape: [1024, 4096]
+188: model.layers.27.self_attn.o_proj.weight shape: [4096, 4096]
+189: model.layers.27.self_attn.q_proj.weight shape: [4096, 4096]
+190: model.layers.27.self_attn.v_proj.weight shape: [1024, 4096]
+191: model.layers.28.input_layernorm.weight shape: [4096]
+192: model.layers.28.mlp.down_proj.weight shape: [4096, 14336]
+193: model.layers.28.mlp.gate_proj.weight shape: [14336, 4096]
+194: model.layers.28.mlp.up_proj.weight shape: [14336, 4096]
+195: model.layers.28.post_attention_layernorm.weight shape: [4096]
+196: model.layers.28.self_attn.k_proj.weight shape: [1024, 4096]
+197: model.layers.28.self_attn.o_proj.weight shape: [4096, 4096]
+198: model.layers.28.self_attn.q_proj.weight shape: [4096, 4096]
+199: model.layers.28.self_attn.v_proj.weight shape: [1024, 4096]
+200: model.layers.29.input_layernorm.weight shape: [4096]
+201: model.layers.29.mlp.down_proj.weight shape: [4096, 14336]
+202: model.layers.29.mlp.gate_proj.weight shape: [14336, 4096]
+203: model.layers.29.mlp.up_proj.weight shape: [14336, 4096]
+204: model.layers.29.post_attention_layernorm.weight shape: [4096]
+205: model.layers.29.self_attn.k_proj.weight shape: [1024, 4096]
+206: model.layers.29.self_attn.o_proj.weight shape: [4096, 4096]
+207: model.layers.29.self_attn.q_proj.weight shape: [4096, 4096]
+208: model.layers.29.self_attn.v_proj.weight shape: [1024, 4096]
+209: model.layers.3.input_layernorm.weight shape: [4096]
+210: model.layers.3.mlp.down_proj.weight shape: [4096, 14336]
+211: model.layers.3.mlp.gate_proj.weight shape: [14336, 4096]
+212: model.layers.3.mlp.up_proj.weight shape: [14336, 4096]
+213: model.layers.3.post_attention_layernorm.weight shape: [4096]
+214: model.layers.3.self_attn.k_proj.weight shape: [1024, 4096]
+215: model.layers.3.self_attn.o_proj.weight shape: [4096, 4096]
+216: model.layers.3.self_attn.q_proj.weight shape: [4096, 4096]
+217: model.layers.3.self_attn.v_proj.weight shape: [1024, 4096]
+218: model.layers.30.input_layernorm.weight shape: [4096]
+219: model.layers.30.mlp.down_proj.weight shape: [4096, 14336]
+220: model.layers.30.mlp.gate_proj.weight shape: [14336, 4096]
+221: model.layers.30.mlp.up_proj.weight shape: [14336, 4096]
+222: model.layers.30.post_attention_layernorm.weight shape: [4096]
+223: model.layers.30.self_attn.k_proj.weight shape: [1024, 4096]
+224: model.layers.30.self_attn.o_proj.weight shape: [4096, 4096]
+225: model.layers.30.self_attn.q_proj.weight shape: [4096, 4096]
+226: model.layers.30.self_attn.v_proj.weight shape: [1024, 4096]
+227: model.layers.31.input_layernorm.weight shape: [4096]
+228: model.layers.31.mlp.down_proj.weight shape: [4096, 14336]
+229: model.layers.31.mlp.gate_proj.weight shape: [14336, 4096]
+230: model.layers.31.mlp.up_proj.weight shape: [14336, 4096]
+231: model.layers.31.post_attention_layernorm.weight shape: [4096]
+232: model.layers.31.self_attn.k_proj.weight shape: [1024, 4096]
+233: model.layers.31.self_attn.o_proj.weight shape: [4096, 4096]
+234: model.layers.31.self_attn.q_proj.weight shape: [4096, 4096]
+235: model.layers.31.self_attn.v_proj.weight shape: [1024, 4096]
+236: model.layers.4.input_layernorm.weight shape: [4096]
+237: model.layers.4.mlp.down_proj.weight shape: [4096, 14336]
+238: model.layers.4.mlp.gate_proj.weight shape: [14336, 4096]
+239: model.layers.4.mlp.up_proj.weight shape: [14336, 4096]
+240: model.layers.4.post_attention_layernorm.weight shape: [4096]
+241: model.layers.4.self_attn.k_proj.weight shape: [1024, 4096]
+242: model.layers.4.self_attn.o_proj.weight shape: [4096, 4096]
+243: model.layers.4.self_attn.q_proj.weight shape: [4096, 4096]
+244: model.layers.4.self_attn.v_proj.weight shape: [1024, 4096]
+245: model.layers.5.input_layernorm.weight shape: [4096]
+246: model.layers.5.mlp.down_proj.weight shape: [4096, 14336]
+247: model.layers.5.mlp.gate_proj.weight shape: [14336, 4096]
+248: model.layers.5.mlp.up_proj.weight shape: [14336, 4096]
+249: model.layers.5.post_attention_layernorm.weight shape: [4096]
+250: model.layers.5.self_attn.k_proj.weight shape: [1024, 4096]
+251: model.layers.5.self_attn.o_proj.weight shape: [4096, 4096]
+252: model.layers.5.self_attn.q_proj.weight shape: [4096, 4096]
+253: model.layers.5.self_attn.v_proj.weight shape: [1024, 4096]
+254: model.layers.6.input_layernorm.weight shape: [4096]
+255: model.layers.6.mlp.down_proj.weight shape: [4096, 14336]
+256: model.layers.6.mlp.gate_proj.weight shape: [14336, 4096]
+257: model.layers.6.mlp.up_proj.weight shape: [14336, 4096]
+258: model.layers.6.post_attention_layernorm.weight shape: [4096]
+259: model.layers.6.self_attn.k_proj.weight shape: [1024, 4096]
+260: model.layers.6.self_attn.o_proj.weight shape: [4096, 4096]
+261: model.layers.6.self_attn.q_proj.weight shape: [4096, 4096]
+262: model.layers.6.self_attn.v_proj.weight shape: [1024, 4096]
+263: model.layers.7.input_layernorm.weight shape: [4096]
+264: model.layers.7.mlp.down_proj.weight shape: [4096, 14336]
+265: model.layers.7.mlp.gate_proj.weight shape: [14336, 4096]
+266: model.layers.7.mlp.up_proj.weight shape: [14336, 4096]
+267: model.layers.7.post_attention_layernorm.weight shape: [4096]
+268: model.layers.7.self_attn.k_proj.weight shape: [1024, 4096]
+269: model.layers.7.self_attn.o_proj.weight shape: [4096, 4096]
+270: model.layers.7.self_attn.q_proj.weight shape: [4096, 4096]
+271: model.layers.7.self_attn.v_proj.weight shape: [1024, 4096]
+272: model.layers.8.input_layernorm.weight shape: [4096]
+273: model.layers.8.mlp.down_proj.weight shape: [4096, 14336]
+274: model.layers.8.mlp.gate_proj.weight shape: [14336, 4096]
+275: model.layers.8.mlp.up_proj.weight shape: [14336, 4096]
+276: model.layers.8.post_attention_layernorm.weight shape: [4096]
+277: model.layers.8.self_attn.k_proj.weight shape: [1024, 4096]
+278: model.layers.8.self_attn.o_proj.weight shape: [4096, 4096]
+279: model.layers.8.self_attn.q_proj.weight shape: [4096, 4096]
+280: model.layers.8.self_attn.v_proj.weight shape: [1024, 4096]
+281: model.layers.9.input_layernorm.weight shape: [4096]
+282: model.layers.9.mlp.down_proj.weight shape: [4096, 14336]
+283: model.layers.9.mlp.gate_proj.weight shape: [14336, 4096]
+284: model.layers.9.mlp.up_proj.weight shape: [14336, 4096]
+285: model.layers.9.post_attention_layernorm.weight shape: [4096]
+286: model.layers.9.self_attn.k_proj.weight shape: [1024, 4096]
+287: model.layers.9.self_attn.o_proj.weight shape: [4096, 4096]
+288: model.layers.9.self_attn.q_proj.weight shape: [4096, 4096]
+289: model.layers.9.self_attn.v_proj.weight shape: [1024, 4096]
+290: model.norm.weight shape: [4096]
diff --git a/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.TokenizerTest.approved.txt b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.TokenizerTest.approved.txt
new file mode 100644
index 0000000000..fc0568084b
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.LLaMA.Tests/Approvals/LLaMA3_1Tests.TokenizerTest.approved.txt
@@ -0,0 +1,8 @@
+Can you provide ways to eat combinations of bananas and dragonfruits?
+6854, 499, 3493, 5627, 311, 8343, 28559, 315, 68442, 323, 26161, 1658, 12059, 30
+Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey.
+40914, 0, 5810, 527, 1063, 5627, 311, 8343, 68442, 323, 26161, 1658, 12059, 3871, 25, 220, 16, 13, 76924, 323, 26161, 36698, 11113, 648, 25, 55248, 68442, 323, 26161, 1658, 12059, 3871, 449, 1063, 14403, 323, 26828, 13, 220, 17, 13, 76924, 323, 26161, 36698, 33566, 25, 19771, 48715, 68442, 323, 26161, 1658, 12059, 3871, 449, 1063, 30564, 23661, 323, 26828, 13
+What about solving an 2x + 3 = 7 equation?
+3923, 922, 22581, 459, 220, 17, 87, 489, 220, 18, 284, 220, 22, 24524, 30
+<|begin_of_text|>Hello World<|end_of_text|>
+128000, 9906, 4435, 128001
diff --git a/test/Microsoft.ML.GenAI.LLaMA.Tests/LLaMA3_1Tests.TokenizerTest.received.txt b/test/Microsoft.ML.GenAI.LLaMA.Tests/LLaMA3_1Tests.TokenizerTest.received.txt
new file mode 100644
index 0000000000..9bb3220214
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.LLaMA.Tests/LLaMA3_1Tests.TokenizerTest.received.txt
@@ -0,0 +1,6 @@
+Can you provide ways to eat combinations of bananas and dragonfruits?
+6854, 499, 3493, 5627, 311, 8343, 28559, 315, 68442, 323, 26161, 1658, 12059, 30
+Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey.
+40914, 0, 5810, 527, 1063, 5627, 311, 8343, 68442, 323, 26161, 1658, 12059, 3871, 25, 220, 16, 13, 76924, 323, 26161, 36698, 11113, 648, 25, 55248, 68442, 323, 26161, 1658, 12059, 3871, 449, 1063, 14403, 323, 26828, 13, 220, 17, 13, 76924, 323, 26161, 36698, 33566, 25, 19771, 48715, 68442, 323, 26161, 1658, 12059, 3871, 449, 1063, 30564, 23661, 323, 26828, 13
+What about solving an 2x + 3 = 7 equation?
+3923, 922, 22581, 459, 220, 17, 87, 489, 220, 18, 284, 220, 22, 24524, 30
diff --git a/test/Microsoft.ML.GenAI.LLaMA.Tests/LLaMA3_1Tests.cs b/test/Microsoft.ML.GenAI.LLaMA.Tests/LLaMA3_1Tests.cs
new file mode 100644
index 0000000000..7d97150f7b
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.LLaMA.Tests/LLaMA3_1Tests.cs
@@ -0,0 +1,125 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Text;
+using ApprovalTests;
+using ApprovalTests.Namers;
+using ApprovalTests.Reporters;
+using AutoGen.Core;
+using Microsoft.ML.GenAI.Core.Extension;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+using TorchSharp;
+using Xunit;
+
+namespace Microsoft.ML.GenAI.LLaMA.Tests;
+
+[Collection("NoParallelization")]
+public class LLaMA3_1Tests
+{
+ public LLaMA3_1Tests()
+ {
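+ // HELIX_CORRELATION_ID is set on Helix CI machines; there, resolve approved files next to the test assembly instead of the source tree.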
+ if (Environment.GetEnvironmentVariable("HELIX_CORRELATION_ID") != null)
+ {
+ Approvals.UseAssemblyLocationForApprovedFiles();
+ }
+
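+ // The "meta" device creates tensors with shapes but no backing data, so the shape tests below can build the full model without real weights.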
+ torch.set_default_device("meta");
+ }
+
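+ // PeekShape dumps every parameter name and shape; the approval file pins the expected Llama 3.1 8B layout.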
+ [Fact]
+ [UseReporter(typeof(DiffReporter))]
+ [UseApprovalSubdirectory("Approvals")]
+ public void Llama_3_1_8b_ShapeTest()
+ {
+ var model = new LlamaForCausalLM(LlamaConfig.Llama3_1_8B_Instruct, "meta");
+ var stateDictStr = model.PeekShape();
+ Approvals.Verify(stateDictStr);
+ }
+
+ [WindowsOnlyFact]
+ [UseReporter(typeof(DiffReporter))]
+ [UseApprovalSubdirectory("Approvals")]
+ public void Llama_3_1_70b_ShapeTest()
+ {
+ var model = new LlamaForCausalLM(LlamaConfig.Llama3_1_70B_Instruct, "meta");
+ var stateDictStr = model.PeekShape();
+ Approvals.Verify(stateDictStr);
+ }
+
+ [WindowsOnlyFact]
+ [UseReporter(typeof(DiffReporter))]
+ [UseApprovalSubdirectory("Approvals")]
+ public void Llama_3_1_405b_ShapeTest()
+ {
+ var model = new LlamaForCausalLM(LlamaConfig.Llama3_1_405B_Instruct, "meta");
+ var stateDictStr = model.PeekShape();
+ Approvals.Verify(stateDictStr);
+ }
+
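+ // Encodes each prompt to token ids with the Llama-3.1 tokenizer, decodes it back, and verifies both against the approved file.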
+ [Fact]
+ [UseReporter(typeof(DiffReporter))]
+ [UseApprovalSubdirectory("Approvals")]
+ public void TokenizerTest()
+ {
+ var modelWeightFolder = Path.Join("Llama-3.1");
+ var tokenizer = LlamaTokenizerHelper.FromPretrained(modelWeightFolder);
+
+ var messages = new string[]
+ {
+ "Can you provide ways to eat combinations of bananas and dragonfruits?",
+ "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey.",
+ "What about solving an 2x + 3 = 7 equation?",
+ """
+ <|begin_of_text|>Hello World<|end_of_text|>
+ """
+ };
+
+ var sb = new StringBuilder();
+ foreach (var message in messages)
+ {
+ var tokenizeIds = tokenizer.EncodeToIds(message, true, false);
+ var decodeToString = tokenizer.Decode(tokenizeIds);
+ sb.AppendLine(decodeToString);
+ var tokenizedStr = string.Join(", ", tokenizeIds.Select(x => x.ToString()));
+
+ sb.AppendLine(tokenizedStr);
+ }
+ Approvals.Verify(sb.ToString());
+ }
+
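+ // Renders an AutoGen chat history through the Llama 3.1 chat template and approval-tests the resulting prompt.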
+ [Fact]
+ [UseReporter(typeof(DiffReporter))]
+ [UseApprovalSubdirectory("Approvals")]
+ public void ItBuildChatTemplateFromAutoGenChatHistory()
+ {
+ var chatHistory = new List<IMessage>
+ {
+ new TextMessage(Role.System, "You are a helpful AI assistant."),
+ new TextMessage(Role.User, "Hello?"),
+ new TextMessage(Role.Assistant, "World!"),
+ };
+
+ var prompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(chatHistory);
+
+ Approvals.Verify(prompt);
+ }
+
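+ // Same check as above, but starting from a Semantic Kernel ChatHistory.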
+ [Fact]
+ [UseReporter(typeof(DiffReporter))]
+ [UseApprovalSubdirectory("Approvals")]
+ public void ItBuildChatTemplateFromSemanticKernelChatHistory()
+ {
+ var chatHistory = new ChatHistory
+ {
+ new ChatMessageContent(AuthorRole.System, "You are a helpful AI assistant."),
+ new ChatMessageContent(AuthorRole.User, "Hello?"),
+ new ChatMessageContent(AuthorRole.Assistant, "World!"),
+ };
+
+ var prompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(chatHistory);
+
+ Approvals.Verify(prompt);
+ }
+}
diff --git a/test/Microsoft.ML.GenAI.LLaMA.Tests/Microsoft.ML.GenAI.LLaMA.Tests.csproj b/test/Microsoft.ML.GenAI.LLaMA.Tests/Microsoft.ML.GenAI.LLaMA.Tests.csproj
new file mode 100644
index 0000000000..643c1d91b2
--- /dev/null
+++ b/test/Microsoft.ML.GenAI.LLaMA.Tests/Microsoft.ML.GenAI.LLaMA.Tests.csproj
@@ -0,0 +1,44 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <!-- Element names below are a best-effort reconstruction; only the element values survived in this capture. -->
+  <PropertyGroup>
+    <TargetFramework>net6.0</TargetFramework>
+    <ImplicitUsings>enable</ImplicitUsings>
+    <NoWarn>$(NoWarn);MSML_ExtendBaseTestClass</NoWarn>
+    <Nullable>enable</Nullable>
+    <IsTestProject>true</IsTestProject>
+  </PropertyGroup>
+
+  <ItemGroup>
+    <!-- PackageReference and ProjectReference items: names not recoverable from this capture. -->
+  </ItemGroup>
+
+  <ItemGroup>
+    <None> <!-- Include path not recoverable from this capture -->
+      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
+    </None>
+  </ItemGroup>
+
+</Project>