Commit 7c937bf

[GenAI] Add generateEmbedding API to CausalLMPipeline (#7227)
* add embedding
* add FromPretrained API to Phi3 model
* fix bug
* Update CausalLMPipeline.cs
1 parent 1d1cc99 commit 7c937bf
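
In short, the commit adds an embedding entry point to CausalLMPipeline and a device-map/quantization-aware loader overload to Phi3ForCasualLM. A condensed sketch of the two additions used together, assuming the setup from the sample diffs below (the weight folder and the "cuda" device string are placeholders):

    // placeholder: point this at a local copy of Phi-3-mini-4k-instruct
    var weightFolder = @"path\to\Phi-3-mini-4k-instruct";
    var tokenizer = Phi3TokenizerHelper.FromPretrained(Path.Combine(weightFolder, "tokenizer.model"));
    // new overload: optional int8/int4 quantization and per-layer device placement
    var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
    var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, "cuda");
    // new API: L2-normalized embedding taken from the last token's hidden state
    float[] embedding = pipeline.GenerateEmbeddingFromLastTokenPool("hello world");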

File tree

6 files changed: +93 -111 lines changed

docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/AutoGenSample.cs

Lines changed: 6 additions & 2 deletions
@@ -9,6 +9,7 @@
 using TorchSharp;
 using Microsoft.ML.GenAI.Core;
 using Microsoft.ML.GenAI.Core.Extension;
+using Microsoft.ML.Tokenizers;
 
 namespace Microsoft.ML.GenAI.Samples.Phi3Mini;
 
@@ -26,12 +27,15 @@ public static async Task RunAsync()
         torch.manual_seed(1);
         torch.set_default_dtype(defaultType);
         var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
-        var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device, quantizeToInt8: false);
+        var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
+        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
+        var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
+        var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
+        var question = @"write a C# program to calculate the factorial of a number";
 
         // agent
         var agent = new Phi3Agent(pipeline, "assistant")
             .RegisterPrintMessage();
-        var question = @"write a C# program to calculate the factorial of a number";
 
         // chat with the assistant
         await agent.SendAsync(question);

docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/SemanticKernelSample.cs

Lines changed: 12 additions & 5 deletions
@@ -1,4 +1,7 @@
-using Microsoft.ML.GenAI.Phi.Extension;
+using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Phi;
+using Microsoft.ML.GenAI.Phi.Extension;
+using Microsoft.ML.Tokenizers;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
 using TorchSharp;
@@ -20,8 +23,10 @@ public static async Task RunChatCompletionSample()
         torch.manual_seed(1);
         torch.set_default_dtype(defaultType);
         var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
-        var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device: device);
-
+        var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
+        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
+        var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
+        var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
 
         var kernel = Kernel.CreateBuilder()
             .AddGenAIChatCompletion(pipeline)
@@ -49,8 +54,10 @@ public static async Task RunTextGenerationSample()
         torch.manual_seed(1);
         torch.set_default_dtype(defaultType);
         var weightFolder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
-        var pipeline = Utils.LoadPhi3Mini4KFromFolder(weightFolder, device);
-
+        var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
+        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
+        var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
+        var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
 
         var kernel = Kernel.CreateBuilder()
             .AddGenAITextGeneration(pipeline)

docs/samples/Microsoft.ML.GenAI.Samples/Phi3Mini/Utils.cs

Lines changed: 0 additions & 103 deletions
This file was deleted.

docs/samples/Microsoft.ML.GenAI.Samples/Program.cs

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 // See https://aka.ms/new-console-template for more information
 using Microsoft.ML.GenAI.Samples.Phi3Mini;
 
-await SemanticKernelSample.RunChatCompletionSample();
+await AutoGenSample.RunAsync();

src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs

Lines changed: 24 additions & 0 deletions
@@ -32,6 +32,11 @@ string Generate(
         float topP = CausalLMPipeline.Defaults.TopP,
         string[]? stopSequences = CausalLMPipeline.Defaults.StopSequence);
 
+    /// <summary>
+    /// Generate the embedding (the last hidden state of the last token) for the prompt. The embedding is normalized by its L2 norm.
+    /// </summary>
+    float[] GenerateEmbeddingFromLastTokenPool(string prompt);
+
     IEnumerable<string> GenerateStreaming(
         string prompt,
         int maxLen = CausalLMPipeline.Defaults.MaxLen,
@@ -281,4 +286,23 @@ protected torch.Tensor SampleTopP(torch.Tensor logits, float topP)
         nextToken = torch.gather(probsIndex, dim: -1, index: nextToken);
         return nextToken;
     }
+
+    public float[] GenerateEmbeddingFromLastTokenPool(string prompt)
+    {
+        using var scope = NewDisposeScope();
+        using var noGrad = torch.no_grad();
+        var inputIds = this.Tokenizer.EncodeToIds(prompt);
+        var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: this.Device).unsqueeze(0);
+        var attentionMask = torch.ones_like(inputTensor, device: this.Device);
+        var input = new CausalLMModelInput(inputTensor, attentionMask, pastKeyValuesLength: 0);
+        var output = this.Model.forward(input);
+        var lastTokenHiddenState = output.LastHiddenState[0, ^1];
+
+        // shape of lastTokenHiddenState: [hidden_size]
+        // normalize by the L2 norm
+        var norm = lastTokenHiddenState.norm();
+        var normalized = lastTokenHiddenState / norm;
+
+        return normalized.to_type(ScalarType.Float32).data<float>().ToArray();
+    }
 }
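
Because the returned vector is already L2-normalized, cosine similarity between two prompts reduces to a plain dot product. A minimal usage sketch, assuming `pipeline` is constructed as in the sample diffs above (the prompt strings are arbitrary):

    // embeddings come back unit-length, so the dot product is the cosine similarity
    float[] a = pipeline.GenerateEmbeddingFromLastTokenPool("The quick brown fox");
    float[] b = pipeline.GenerateEmbeddingFromLastTokenPool("A fast auburn fox");

    var cosine = 0f;
    for (var i = 0; i < a.Length; i++)
    {
        cosine += a[i] * b[i];
    }

    Console.WriteLine($"cosine similarity: {cosine:F4}");

Last-token pooling is the natural choice for a causal LM: under a left-to-right attention mask, only the final position has attended to the entire prompt, so its hidden state is the one that summarizes the whole input.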

src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs

Lines changed: 50 additions & 0 deletions
@@ -9,6 +9,7 @@
 using System.Text.Json;
 using System.Threading.Tasks;
 using Microsoft.ML.GenAI.Core;
+using Microsoft.ML.GenAI.Core.Extension;
 using Microsoft.ML.GenAI.Phi.Module;
 using TorchSharp;
 using TorchSharp.Modules;
@@ -66,6 +67,55 @@ public static Phi3ForCasualLM FromPretrained(
         return phi;
     }
 
+    public static Phi3ForCasualLM FromPretrained(
+        string modelFolder,
+        string configName = "config.json",
+        string checkPointName = "model.safetensors.index.json",
+        bool quantizeToInt8 = false,
+        bool quantizeToInt4 = false,
+        int layersOnTargetDevice = -1,
+        ScalarType torchDtype = ScalarType.BFloat16,
+        string targetDevice = "cuda")
+    {
+        if (layersOnTargetDevice == -1 && quantizeToInt4 == false && quantizeToInt8 == false)
+        {
+            return FromPretrained(modelFolder, configName, checkPointName, torchDtype, targetDevice);
+        }
+
+        var originalDefaultDevice = torch.get_default_device();
+        torch.set_default_device("meta");
+        var config = Path.Join(modelFolder, configName);
+        var modelConfig = JsonSerializer.Deserialize<Phi3Config>(File.ReadAllText(config)) ?? throw new ArgumentNullException(nameof(config));
+        modelConfig.DType = torchDtype;
+        var model = new Phi3ForCasualLM(modelConfig);
+
+        if (quantizeToInt8)
+        {
+            model.ToInt8QuantizeModule();
+        }
+        else if (quantizeToInt4)
+        {
+            model.ToInt4QuantizeModule();
+        }
+
+        var deviceMap = model.InferDeviceMapForEachLayer(
+            [
+                KeyValuePair.Create(targetDevice, layersOnTargetDevice),
+                KeyValuePair.Create("cpu", -1)
+            ]);
+
+        torch.set_default_device("cpu");
+        model = new Phi3ForCasualLM(modelConfig);
+
+        model.LoadSafeTensors(modelFolder, checkPointName);
+
+        model = model.ToDynamicLoadingModel(deviceMap, targetDevice);
+
+        torch.set_default_device(originalDefaultDevice);
+
+        return model;
+    }
+
     public void LoadSafeTensors(string modelFolder, string checkPointName = "model.safetensors.index.json")
     {
         this.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, useTqdm: false);
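
A usage sketch of the new overload. The weight folder is a placeholder, and the reading of the -1 sentinel is inferred from the guard clause and the InferDeviceMapForEachLayer call above rather than from documentation (naming layersOnTargetDevice also disambiguates from the original overload):

    var weightFolder = @"path\to\Phi-3-mini-4k-instruct"; // placeholder

    // -1 layers and no quantization falls through to the original full-precision overload
    var full = Phi3ForCasualLM.FromPretrained(weightFolder, layersOnTargetDevice: -1);

    // int8-quantized; -1 appears to ask the device map to place every layer on "cuda"
    var int8 = Phi3ForCasualLM.FromPretrained(weightFolder, quantizeToInt8: true, layersOnTargetDevice: -1);

    // int4-quantized with only the first 16 layers on "cuda"; the remainder stays on "cpu"
    var int4 = Phi3ForCasualLM.FromPretrained(weightFolder, quantizeToInt4: true, layersOnTargetDevice: 16);

Note the two-pass construction in the implementation: the model is first materialized on the meta device, which allocates no storage, so the per-layer device map can be inferred cheaply; it is then rebuilt on cpu, where LoadSafeTensors reads the weights before ToDynamicLoadingModel distributes layers according to the map.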
