Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ protected Object checkPrivateKey(byte[] sk) throws InvalidKeyException {
/*
Main internal algorithms from Section 6 of specification
*/
protected ML_KEM_KeyPair generateKemKeyPair(byte[] kem_d, byte[] kem_z) {
protected ML_KEM_KeyPair generateKemKeyPair(byte[] kem_d_z) {
MessageDigest mlKemH;
try {
mlKemH = MessageDigest.getInstance(HASH_H_NAME);
Expand All @@ -508,7 +508,8 @@ protected ML_KEM_KeyPair generateKemKeyPair(byte[] kem_d, byte[] kem_z) {
}

//Generate K-PKE keys
var kPkeKeyPair = generateK_PkeKeyPair(kem_d);
//The 1st 32-byte `d` is used in K-PKE key pair generation
var kPkeKeyPair = generateK_PkeKeyPair(kem_d_z);
//encaps key = kPke encryption key
byte[] encapsKey = kPkeKeyPair.publicKey.keyBytes;

Expand All @@ -527,14 +528,21 @@ protected ML_KEM_KeyPair generateKemKeyPair(byte[] kem_d, byte[] kem_z) {
// This should never happen.
throw new RuntimeException(e);
}
System.arraycopy(kem_z, 0, decapsKey,
// The 2nd 32-byte `z` is copied into decapsKey
System.arraycopy(kem_d_z, 32, decapsKey,
kPkePrivateKey.length + encapsKey.length + 32, 32);

return new ML_KEM_KeyPair(
new ML_KEM_EncapsulationKey(encapsKey),
new ML_KEM_DecapsulationKey(decapsKey));
}

public byte[] privKeyToPubKey(byte[] decapsKey) {
int pkLen = (mlKem_k * ML_KEM_N * 12) / 8 + 32 /* rho */;
int skLen = (mlKem_k * ML_KEM_N * 12) / 8;
return Arrays.copyOfRange(decapsKey, skLen, skLen + pkLen);
}

protected ML_KEM_EncapsulateResult encapsulate(
ML_KEM_EncapsulationKey encapsulationKey, byte[] randomMessage) {
MessageDigest mlKemH;
Expand Down Expand Up @@ -648,10 +656,12 @@ private K_PKE_KeyPair generateK_PkeKeyPair(byte[] seed) {
throw new RuntimeException(e);
}

mlKemG.update(seed);
// Note: only the 1st 32-byte in the seed is used
mlKemG.update(seed, 0, 32);
mlKemG.update((byte)mlKem_k);

var rhoSigma = mlKemG.digest();
mlKemG.reset();
var rho = Arrays.copyOfRange(rhoSigma, 0, 32);
var sigma = Arrays.copyOfRange(rhoSigma, 32, 64);
Arrays.fill(rhoSigma, (byte)0);
Expand Down
110 changes: 80 additions & 30 deletions src/java.base/share/classes/com/sun/crypto/provider/ML_KEM_Impls.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand All @@ -26,9 +26,12 @@
package com.sun.crypto.provider;

import sun.security.jca.JCAUtil;
import sun.security.pkcs.NamedPKCS8Key;
import sun.security.provider.NamedKEM;
import sun.security.provider.NamedKeyFactory;
import sun.security.provider.NamedKeyPairGenerator;
import sun.security.util.KeyChoices;
import sun.security.x509.NamedX509Key;

import java.security.*;
import java.util.Arrays;
Expand All @@ -37,6 +40,20 @@

public final class ML_KEM_Impls {

private static final int SEED_LEN = 64;

public static byte[] seedToExpanded(String pname, byte[] seed) {
return new ML_KEM(pname).generateKemKeyPair(seed)
.decapsulationKey()
.keyBytes();
}

public static NamedX509Key privKeyToPubKey(NamedPKCS8Key npk) {
return new NamedX509Key(npk.getAlgorithm(),
npk.getParams().getName(),
new ML_KEM(npk.getParams().getName()).privKeyToPubKey(npk.getExpanded()));
}

public sealed static class KPG
extends NamedKeyPairGenerator permits KPG2, KPG3, KPG5 {

Expand All @@ -50,25 +67,27 @@ protected KPG(String pname) {
}

@Override
protected byte[][] implGenerateKeyPair(String name, SecureRandom random) {
byte[] seed = new byte[32];
protected byte[][] implGenerateKeyPair(String pname, SecureRandom random) {
byte[] seed = new byte[SEED_LEN];
var r = random != null ? random : JCAUtil.getDefSecureRandom();
r.nextBytes(seed);
byte[] z = new byte[32];
r.nextBytes(z);

ML_KEM mlKem = new ML_KEM(name);
ML_KEM mlKem = new ML_KEM(pname);
ML_KEM.ML_KEM_KeyPair kp;
kp = mlKem.generateKemKeyPair(seed);
var expanded = kp.decapsulationKey().keyBytes();

try {
kp = mlKem.generateKemKeyPair(seed, z);
return new byte[][]{
kp.encapsulationKey().keyBytes(),
KeyChoices.writeToChoice(
KeyChoices.getPreferred("mlkem"),
seed, expanded),
expanded
};
} finally {
Arrays.fill(seed, (byte)0);
Arrays.fill(z, (byte)0);
Arrays.fill(seed, (byte) 0);
}
return new byte[][] {
kp.encapsulationKey().keyBytes(),
kp.decapsulationKey().keyBytes()
};
}
}

Expand All @@ -94,8 +113,39 @@ public sealed static class KF extends NamedKeyFactory permits KF2, KF3, KF5 {
public KF() {
super("ML-KEM", "ML-KEM-512", "ML-KEM-768", "ML-KEM-1024");
}
public KF(String name) {
super("ML-KEM", name);
public KF(String pname) {
super("ML-KEM", pname);
}

@Override
protected byte[] implExpand(String pname, byte[] input)
throws InvalidKeyException {
return KeyChoices.choiceToExpanded(pname, SEED_LEN, input,
ML_KEM_Impls::seedToExpanded);
}

@Override
protected Key engineTranslateKey(Key key) throws InvalidKeyException {
var nk = toNamedKey(key);
if (nk instanceof NamedPKCS8Key npk) {
var type = KeyChoices.getPreferred("mlkem");
if (KeyChoices.typeOfChoice(npk.getRawBytes()) != type) {
var encoding = KeyChoices.choiceToChoice(
type,
npk.getParams().getName(),
SEED_LEN, npk.getRawBytes(),
ML_KEM_Impls::seedToExpanded);
nk = NamedPKCS8Key.internalCreate(
npk.getAlgorithm(),
npk.getParams().getName(),
encoding,
npk.getExpanded().clone());
if (npk != key) { // npk is neither input or output
npk.destroy();
}
}
}
return nk;
}
}

Expand All @@ -121,15 +171,15 @@ public sealed static class K extends NamedKEM permits K2, K3, K5 {
private static final int SEED_SIZE = 32;

@Override
protected byte[][] implEncapsulate(String name, byte[] encapsulationKey,
protected byte[][] implEncapsulate(String pname, byte[] encapsulationKey,
Object ek, SecureRandom secureRandom) {

byte[] randomBytes = new byte[SEED_SIZE];
var r = secureRandom != null ? secureRandom : JCAUtil.getDefSecureRandom();
r.nextBytes(randomBytes);

ML_KEM mlKem = new ML_KEM(name);
ML_KEM.ML_KEM_EncapsulateResult mlKemEncapsulateResult = null;
ML_KEM mlKem = new ML_KEM(pname);
ML_KEM.ML_KEM_EncapsulateResult mlKemEncapsulateResult;
try {
mlKemEncapsulateResult = mlKem.encapsulate(
new ML_KEM.ML_KEM_EncapsulationKey(
Expand All @@ -145,49 +195,49 @@ protected byte[][] implEncapsulate(String name, byte[] encapsulationKey,
}

@Override
protected byte[] implDecapsulate(String name, byte[] decapsulationKey,
protected byte[] implDecapsulate(String pname, byte[] decapsulationKey,
Object dk, byte[] cipherText)
throws DecapsulateException {

ML_KEM mlKem = new ML_KEM(name);
ML_KEM mlKem = new ML_KEM(pname);
var kpkeCipherText = new ML_KEM.K_PKE_CipherText(cipherText);
return mlKem.decapsulate(new ML_KEM.ML_KEM_DecapsulationKey(
decapsulationKey), kpkeCipherText);
}

@Override
protected int implSecretSize(String name) {
protected int implSecretSize(String pname) {
return ML_KEM.SECRET_SIZE;
}

@Override
protected int implEncapsulationSize(String name) {
ML_KEM mlKem = new ML_KEM(name);
protected int implEncapsulationSize(String pname) {
ML_KEM mlKem = new ML_KEM(pname);
return mlKem.getEncapsulationSize();
}

@Override
protected Object implCheckPublicKey(String name, byte[] pk)
protected Object implCheckPublicKey(String pname, byte[] pk)
throws InvalidKeyException {

ML_KEM mlKem = new ML_KEM(name);
ML_KEM mlKem = new ML_KEM(pname);
return mlKem.checkPublicKey(pk);
}

@Override
protected Object implCheckPrivateKey(String name, byte[] sk)
protected Object implCheckPrivateKey(String pname, byte[] sk)
throws InvalidKeyException {

ML_KEM mlKem = new ML_KEM(name);
ML_KEM mlKem = new ML_KEM(pname);
return mlKem.checkPrivateKey(sk);
}

public K() {
super("ML-KEM", "ML-KEM-512", "ML-KEM-768", "ML-KEM-1024");
super("ML-KEM", new KF());
}

public K(String name) {
super("ML-KEM", name);
public K(String pname) {
super("ML-KEM", new KF(pname));
}
}

Expand Down
Loading