1616
1717package org .springframework .security .saml2 .provider .service .registration ;
1818
19- import java .io .IOException ;
20- import java .sql .PreparedStatement ;
2119import java .sql .ResultSet ;
2220import java .sql .SQLException ;
2321import java .sql .Types ;
24- import java .util .ArrayList ;
2522import java .util .Collection ;
2623import java .util .Iterator ;
2724import java .util .List ;
28- import java .util .function .Function ;
25+ import java .util .function .Consumer ;
2926
3027import org .apache .commons .logging .Log ;
3128import org .apache .commons .logging .LogFactory ;
32- import org .slf4j .Logger ;
33- import org .slf4j .LoggerFactory ;
3429import org .springframework .core .log .LogMessage ;
3530import org .springframework .core .serializer .DefaultDeserializer ;
36- import org .springframework .core .serializer .DefaultSerializer ;
3731import org .springframework .core .serializer .Deserializer ;
38- import org .springframework .core .serializer .Serializer ;
3932import org .springframework .jdbc .core .ArgumentPreparedStatementSetter ;
4033import org .springframework .jdbc .core .JdbcOperations ;
4134import org .springframework .jdbc .core .PreparedStatementSetter ;
4437import org .springframework .security .saml2 .core .Saml2X509Credential ;
4538import org .springframework .security .saml2 .provider .service .registration .RelyingPartyRegistration .AssertingPartyDetails ;
4639import org .springframework .util .Assert ;
40+ import org .springframework .util .StringUtils ;
4741
4842/**
4943 * A JDBC implementation of {@link AssertingPartyMetadataRepository}.
@@ -58,13 +52,9 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
5852 private RowMapper <AssertingPartyMetadata > assertingPartyMetadataRowMapper =
5953 new AssertingPartyMetadataRowMapper (ResultSet ::getBytes );
6054
61- private Function <AssertingPartyMetadata , List <SqlParameterValue >> assertingPartyMetadataParametersMapper =
62- new AssertingPartyMetadataParametersMapper ();
63-
64- private final SetBytes setBytes = PreparedStatement ::setBytes ;
65-
6655 // @formatter:off
6756 static final String COLUMN_NAMES = "entity_id, "
57+ + "metadata_uri, "
6858 + "singlesignon_url, "
6959 + "singlesignon_binding, "
7060 + "singlesignon_sign_request, "
@@ -87,26 +77,6 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
8777
8878 private static final String LOAD_ALL_SQL = "SELECT " + COLUMN_NAMES
8979 + " FROM " + TABLE_NAME ;
90-
91- private static final String SAVE_SQL = "INSERT INTO " + TABLE_NAME + " ("
92- + COLUMN_NAMES
93- + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" ;
94- // @formatter:on
95-
96- private static final String DELETE_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + ENTITY_ID_FILTER ;
97-
98- // @formatter:off
99- private static final String UPDATE_SQL = "UPDATE " + TABLE_NAME
100- + " SET singlesignon_url = ?, " +
101- "singlesignon_binding = ?, " +
102- "singlesignon_sign_request = ?, " +
103- "signing_algorithms = ?, " +
104- "verification_credentials = ?, " +
105- "encryption_credentials = ?, " +
106- "singlelogout_url = ? ," +
107- "singlelogout_response_url = ?, " +
108- "singlelogout_binding = ?"
109- + " WHERE " + ENTITY_ID_FILTER ;
11080 // @formatter:on
11181
11282 /**
@@ -134,41 +104,6 @@ public void setAssertingPartyMetadataRowMapper(
134104 this .assertingPartyMetadataRowMapper = assertingPartyMetadataRowMapper ;
135105 }
136106
137- public void setAssertingPartyMetadataParametersMapper (Function <AssertingPartyMetadata , List <SqlParameterValue >> assertingPartyMetadataParametersMapper ) {
138- Assert .notNull (assertingPartyMetadataParametersMapper , "assertingPartyMetadataParametersMapper cannot be null" );
139- this .assertingPartyMetadataParametersMapper = assertingPartyMetadataParametersMapper ;
140- }
141-
142- public void save (AssertingPartyMetadata metadata ) {
143- Assert .notNull (metadata , "metadata cannot be null" );
144- int rows = update (metadata );
145- if (rows == 0 ) {
146- insert (metadata );
147- }
148- }
149-
150- private void insert (AssertingPartyMetadata metadata ) {
151- List <SqlParameterValue > parameters = this .assertingPartyMetadataParametersMapper .apply (metadata );
152- PreparedStatementSetter pss = new BlobArgumentPreparedStatementSetter (this .setBytes , parameters .toArray ());
153- this .jdbcOperations .update (SAVE_SQL , pss );
154- }
155-
156- private int update (AssertingPartyMetadata metadata ) {
157- List <SqlParameterValue > parameters = this .assertingPartyMetadataParametersMapper .apply (metadata );
158- SqlParameterValue credentialId = parameters .remove (0 );
159- parameters .add (credentialId );
160- PreparedStatementSetter pss = new BlobArgumentPreparedStatementSetter (this .setBytes , parameters .toArray ());
161- return this .jdbcOperations .update (UPDATE_SQL , pss );
162- }
163-
164- public void delete (String entityId ) {
165- Assert .notNull (entityId , "entityId cannot be null" );
166- SqlParameterValue [] parameters = new SqlParameterValue []{
167- new SqlParameterValue (Types .VARCHAR , entityId ),};
168- PreparedStatementSetter pss = new ArgumentPreparedStatementSetter (parameters );
169- this .jdbcOperations .update (DELETE_SQL , pss );
170- }
171-
172107 @ Override
173108 public AssertingPartyMetadata findByEntityId (String entityId ) {
174109 Assert .hasText (entityId , "entityId cannot be empty" );
@@ -187,75 +122,6 @@ public Iterator<AssertingPartyMetadata> iterator() {
187122 return result .iterator ();
188123 }
189124
190- private static class AssertingPartyMetadataParametersMapper
191- implements Function <AssertingPartyMetadata , List <SqlParameterValue >> {
192-
193- private final Logger logger = LoggerFactory .getLogger (AssertingPartyMetadataParametersMapper .class );
194-
195- private final Serializer <Object > serializer = new DefaultSerializer ();
196-
197- @ Override
198- public List <SqlParameterValue > apply (AssertingPartyMetadata record ) {
199- List <SqlParameterValue > parameters = new ArrayList <>();
200- parameters .add (new SqlParameterValue (Types .VARCHAR , record .getEntityId ()));
201- parameters .add (new SqlParameterValue (Types .VARCHAR , record .getSingleSignOnServiceLocation ()));
202- parameters .add (new SqlParameterValue (Types .VARCHAR , record .getSingleSignOnServiceBinding ().getUrn ()));
203- parameters .add (new SqlParameterValue (Types .BOOLEAN , record .getWantAuthnRequestsSigned ()));
204- try {
205- parameters .add (new SqlParameterValue (Types .BLOB ,
206- this .serializer .serializeToByteArray (record .getSigningAlgorithms ())));
207- } catch (IOException ex ) {
208- this .logger .debug ("Failed to serialize signing algorithms" , ex );
209- throw new IllegalArgumentException (ex );
210- }
211- try {
212- parameters .add (new SqlParameterValue (Types .BLOB ,
213- this .serializer .serializeToByteArray (record .getVerificationX509Credentials ())));
214- } catch (IOException ex ) {
215- this .logger .debug ("Failed to serialize verification credentials" , ex );
216- throw new IllegalArgumentException (ex );
217- }
218- try {
219- parameters .add (new SqlParameterValue (Types .BLOB ,
220- this .serializer .serializeToByteArray (record .getEncryptionX509Credentials ())));
221- } catch (IOException ex ) {
222- this .logger .debug ("Failed to serialize encryption credentials" , ex );
223- throw new IllegalArgumentException (ex );
224- }
225- parameters .add (new SqlParameterValue (Types .VARCHAR , record .getSingleLogoutServiceLocation ()));
226- parameters .add (new SqlParameterValue (Types .VARCHAR , record .getSingleLogoutServiceResponseLocation ()));
227- parameters .add (new SqlParameterValue (Types .VARCHAR , record .getSingleLogoutServiceBinding ().getUrn ()));
228- return parameters ;
229- }
230- }
231-
232- private static final class BlobArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
233-
234- private final SetBytes setBytes ;
235-
236- private BlobArgumentPreparedStatementSetter (SetBytes setBytes , Object [] args ) {
237- super (args );
238- this .setBytes = setBytes ;
239- }
240-
241- @ Override
242- protected void doSetValue (PreparedStatement ps , int parameterPosition , Object argValue ) throws SQLException {
243- if (argValue instanceof SqlParameterValue paramValue ) {
244- if (paramValue .getSqlType () == Types .BLOB ) {
245- if (paramValue .getValue () != null ) {
246- Assert .isInstanceOf (byte [].class , paramValue .getValue (),
247- "Value of blob parameter must be byte[]" );
248- }
249- byte [] valueBytes = (byte []) paramValue .getValue ();
250- this .setBytes .setBytes (ps , parameterPosition , valueBytes );
251- return ;
252- }
253- }
254- super .doSetValue (ps , parameterPosition , argValue );
255- }
256-
257- }
258-
259125 /**
260126 * The default {@link RowMapper} that maps the current row in
261127 * {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}.
@@ -275,61 +141,68 @@ private final static class AssertingPartyMetadataRowMapper implements RowMapper<
275141 @ Override
276142 public AssertingPartyMetadata mapRow (ResultSet rs , int rowNum ) throws SQLException {
277143 String entityId = rs .getString ("entity_id" );
144+ String metadataUri = rs .getString ("metadata_uri" );
278145 String singleSignOnUrl = rs .getString ("singlesignon_url" );
279- Saml2MessageBinding singleSignOnBinding = Saml2MessageBinding
280- .from (rs .getString ("singlesignon_binding" ));
146+ Saml2MessageBinding singleSignOnBinding = Saml2MessageBinding .from (rs .getString ("singlesignon_binding" ));
281147 boolean singleSignOnSignRequest = rs .getBoolean ("singlesignon_sign_request" );
282- List <String > signingAlgorithms ;
283- try {
284- signingAlgorithms = (List <String >) deserializer .deserializeFromByteArray (
285- this .getBytes .getBytes (rs , "signing_algorithms" ));
286- } catch (IOException ex ) {
287- this .logger .debug (
288- LogMessage .format ("Verification credentials of %s could not be parsed." , entityId ), ex );
289- return null ;
290- }
291- Collection <Saml2X509Credential > verificationCredentials ;
292- try {
293- verificationCredentials = (Collection <Saml2X509Credential >) deserializer .deserializeFromByteArray (
294- this .getBytes .getBytes (rs , "verification_credentials" ));
295- } catch (IOException ex ) {
296- this .logger .debug (
297- LogMessage .format ("Verification credentials of %s could not be parsed." , entityId ), ex );
298- return null ;
299- }
300- Collection <Saml2X509Credential > encryptionCredentials ;
148+ String singleLogoutUrl = rs .getString ("singlelogout_url" );
149+ String singleLogoutResponseUrl = rs .getString ("singlelogout_response_url" );
150+ Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding .from (rs .getString ("singlelogout_binding" ));
151+ byte [] signingAlgorithmsBytes = this .getBytes .getBytes (rs , "signing_algorithms" );
152+ byte [] verificationCredentialsBytes = this .getBytes .getBytes (rs , "verification_credentials" );
153+ byte [] encryptionCredentialsBytes = this .getBytes .getBytes (rs , "encryption_credentials" );
154+
155+ boolean usingMetadata = StringUtils .hasText (metadataUri );
156+ AssertingPartyMetadata .Builder <?> builder = (!usingMetadata ) ? new AssertingPartyDetails .Builder ().entityId (entityId )
157+ : createBuilderUsingMetadata (entityId , metadataUri );
301158 try {
302- encryptionCredentials = (Collection <Saml2X509Credential >) deserializer .deserializeFromByteArray (
303- this .getBytes .getBytes (rs , "encryption_credentials" ));
304- } catch (IOException ex ) {
159+ if (signingAlgorithmsBytes != null ) {
160+ List <String > signingAlgorithms = (List <String >) deserializer .deserializeFromByteArray (signingAlgorithmsBytes );
161+ builder .signingAlgorithms (algorithms -> algorithms .addAll (signingAlgorithms ));
162+ }
163+ if (verificationCredentialsBytes != null ) {
164+ Collection <Saml2X509Credential > verificationCredentials = (Collection <Saml2X509Credential >) deserializer .deserializeFromByteArray (verificationCredentialsBytes );
165+ builder .verificationX509Credentials (credentials -> credentials .addAll (verificationCredentials ));
166+ }
167+ if (encryptionCredentialsBytes != null ) {
168+ Collection <Saml2X509Credential > encryptionCredentials = (Collection <Saml2X509Credential >) deserializer .deserializeFromByteArray (encryptionCredentialsBytes );
169+ builder .encryptionX509Credentials (credentials -> credentials .addAll (encryptionCredentials ));
170+ }
171+ } catch (Exception ex ) {
305172 this .logger .debug (
306- LogMessage .format ("Encryption credentials of %s could not be parsed. " , entityId ), ex );
173+ LogMessage .format ("Parsing serialized credentials for entity %s failed " , entityId ), ex );
307174 return null ;
308175 }
309- String singleLogoutUrl = rs .getString ("singlelogout_url" );
310- String singleLogoutResponseUrl = rs .getString ("singlelogout_response_url" );
311- Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding
312- .from (rs .getString ("singlelogout_binding" ));
313176
314- return new AssertingPartyDetails .Builder ()
315- .entityId (entityId )
316- .wantAuthnRequestsSigned (singleSignOnSignRequest )
317- .signingAlgorithms (algorithms -> algorithms .addAll (signingAlgorithms ))
318- .verificationX509Credentials (credentials -> credentials .addAll (verificationCredentials ))
319- .encryptionX509Credentials (credentials -> credentials .addAll (encryptionCredentials ))
320- .singleSignOnServiceLocation (singleSignOnUrl )
321- .singleSignOnServiceBinding (singleSignOnBinding )
322- .singleLogoutServiceLocation (singleLogoutUrl )
323- .singleLogoutServiceBinding (singleLogoutBinding )
324- .singleLogoutServiceResponseLocation (singleLogoutResponseUrl )
325- .build ();
177+ applyingWhenNonNull (singleSignOnUrl , builder ::singleSignOnServiceLocation );
178+ applyingWhenNonNull (singleSignOnBinding , builder ::singleSignOnServiceBinding );
179+ applyingWhenNonNull (singleSignOnSignRequest , builder ::wantAuthnRequestsSigned );
180+ applyingWhenNonNull (singleLogoutUrl , builder ::singleLogoutServiceLocation );
181+ applyingWhenNonNull (singleLogoutResponseUrl , builder ::singleLogoutServiceResponseLocation );
182+ applyingWhenNonNull (singleLogoutBinding , builder ::singleLogoutServiceBinding );
183+ return builder .build ();
326184 }
327- }
328185
329- private interface SetBytes {
186+ private <T > void applyingWhenNonNull (T value , Consumer <T > consumer ) {
187+ if (value != null ) {
188+ consumer .accept (value );
189+ }
190+ }
330191
331- void setBytes (PreparedStatement ps , int index , byte [] bytes ) throws SQLException ;
192+ private AssertingPartyMetadata .Builder <?> createBuilderUsingMetadata (String entityId , String metadataUri ) {
193+ Collection <AssertingPartyMetadata .Builder <?>> candidates = AssertingPartyMetadata
194+ .collectionFromMetadataLocation (metadataUri );
195+ for (AssertingPartyMetadata .Builder <?> candidate : candidates ) {
196+ if (entityId == null || entityId .equals (getEntityId (candidate ))) {
197+ return candidate ;
198+ }
199+ }
200+ throw new IllegalStateException ("No asserting party metadata with Entity ID '" + entityId + "' found" );
201+ }
332202
203+ private Object getEntityId (AssertingPartyMetadata .Builder <?> candidate ) {
204+ return candidate .build ().getEntityId ();
205+ }
333206 }
334207
335208 private interface GetBytes {
0 commit comments