@@ -90,10 +90,10 @@ asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
9090 int rounds , int blocks , u8 ctr []);
9191
9292asmlinkage void aes_xts_encrypt (u8 out [], u8 const in [], u32 const rk1 [],
93- int rounds , int blocks , u32 const rk2 [], u8 iv [],
93+ int rounds , int bytes , u32 const rk2 [], u8 iv [],
9494 int first );
9595asmlinkage void aes_xts_decrypt (u8 out [], u8 const in [], u32 const rk1 [],
96- int rounds , int blocks , u32 const rk2 [], u8 iv [],
96+ int rounds , int bytes , u32 const rk2 [], u8 iv [],
9797 int first );
9898
9999asmlinkage void aes_essiv_cbc_encrypt (u8 out [], u8 const in [], u32 const rk1 [],
@@ -527,43 +527,144 @@ static int __maybe_unused xts_encrypt(struct skcipher_request *req)
527527 struct crypto_skcipher * tfm = crypto_skcipher_reqtfm (req );
528528 struct crypto_aes_xts_ctx * ctx = crypto_skcipher_ctx (tfm );
529529 int err , first , rounds = 6 + ctx -> key1 .key_length / 4 ;
530+ int tail = req -> cryptlen % AES_BLOCK_SIZE ;
531+ struct scatterlist sg_src [2 ], sg_dst [2 ];
532+ struct skcipher_request subreq ;
533+ struct scatterlist * src , * dst ;
530534 struct skcipher_walk walk ;
531- unsigned int blocks ;
535+
536+ if (req -> cryptlen < AES_BLOCK_SIZE )
537+ return - EINVAL ;
532538
533539 err = skcipher_walk_virt (& walk , req , false);
534540
535- for (first = 1 ; (blocks = (walk .nbytes / AES_BLOCK_SIZE )); first = 0 ) {
541+ if (unlikely (tail > 0 && walk .nbytes < walk .total )) {
542+ int xts_blocks = DIV_ROUND_UP (req -> cryptlen ,
543+ AES_BLOCK_SIZE ) - 2 ;
544+
545+ skcipher_walk_abort (& walk );
546+
547+ skcipher_request_set_tfm (& subreq , tfm );
548+ skcipher_request_set_callback (& subreq ,
549+ skcipher_request_flags (req ),
550+ NULL , NULL );
551+ skcipher_request_set_crypt (& subreq , req -> src , req -> dst ,
552+ xts_blocks * AES_BLOCK_SIZE ,
553+ req -> iv );
554+ req = & subreq ;
555+ err = skcipher_walk_virt (& walk , req , false);
556+ } else {
557+ tail = 0 ;
558+ }
559+
560+ for (first = 1 ; walk .nbytes >= AES_BLOCK_SIZE ; first = 0 ) {
561+ int nbytes = walk .nbytes ;
562+
563+ if (walk .nbytes < walk .total )
564+ nbytes &= ~(AES_BLOCK_SIZE - 1 );
565+
536566 kernel_neon_begin ();
537567 aes_xts_encrypt (walk .dst .virt .addr , walk .src .virt .addr ,
538- ctx -> key1 .key_enc , rounds , blocks ,
568+ ctx -> key1 .key_enc , rounds , nbytes ,
539569 ctx -> key2 .key_enc , walk .iv , first );
540570 kernel_neon_end ();
541- err = skcipher_walk_done (& walk , walk .nbytes % AES_BLOCK_SIZE );
571+ err = skcipher_walk_done (& walk , walk .nbytes - nbytes );
542572 }
543573
544- return err ;
574+ if (err || likely (!tail ))
575+ return err ;
576+
577+ dst = src = scatterwalk_ffwd (sg_src , req -> src , req -> cryptlen );
578+ if (req -> dst != req -> src )
579+ dst = scatterwalk_ffwd (sg_dst , req -> dst , req -> cryptlen );
580+
581+ skcipher_request_set_crypt (req , src , dst , AES_BLOCK_SIZE + tail ,
582+ req -> iv );
583+
584+ err = skcipher_walk_virt (& walk , & subreq , false);
585+ if (err )
586+ return err ;
587+
588+ kernel_neon_begin ();
589+ aes_xts_encrypt (walk .dst .virt .addr , walk .src .virt .addr ,
590+ ctx -> key1 .key_enc , rounds , walk .nbytes ,
591+ ctx -> key2 .key_enc , walk .iv , first );
592+ kernel_neon_end ();
593+
594+ return skcipher_walk_done (& walk , 0 );
545595}
546596
547597static int __maybe_unused xts_decrypt (struct skcipher_request * req )
548598{
549599 struct crypto_skcipher * tfm = crypto_skcipher_reqtfm (req );
550600 struct crypto_aes_xts_ctx * ctx = crypto_skcipher_ctx (tfm );
551601 int err , first , rounds = 6 + ctx -> key1 .key_length / 4 ;
602+ int tail = req -> cryptlen % AES_BLOCK_SIZE ;
603+ struct scatterlist sg_src [2 ], sg_dst [2 ];
604+ struct skcipher_request subreq ;
605+ struct scatterlist * src , * dst ;
552606 struct skcipher_walk walk ;
553- unsigned int blocks ;
607+
608+ if (req -> cryptlen < AES_BLOCK_SIZE )
609+ return - EINVAL ;
554610
555611 err = skcipher_walk_virt (& walk , req , false);
556612
557- for (first = 1 ; (blocks = (walk .nbytes / AES_BLOCK_SIZE )); first = 0 ) {
613+ if (unlikely (tail > 0 && walk .nbytes < walk .total )) {
614+ int xts_blocks = DIV_ROUND_UP (req -> cryptlen ,
615+ AES_BLOCK_SIZE ) - 2 ;
616+
617+ skcipher_walk_abort (& walk );
618+
619+ skcipher_request_set_tfm (& subreq , tfm );
620+ skcipher_request_set_callback (& subreq ,
621+ skcipher_request_flags (req ),
622+ NULL , NULL );
623+ skcipher_request_set_crypt (& subreq , req -> src , req -> dst ,
624+ xts_blocks * AES_BLOCK_SIZE ,
625+ req -> iv );
626+ req = & subreq ;
627+ err = skcipher_walk_virt (& walk , req , false);
628+ } else {
629+ tail = 0 ;
630+ }
631+
632+ for (first = 1 ; walk .nbytes >= AES_BLOCK_SIZE ; first = 0 ) {
633+ int nbytes = walk .nbytes ;
634+
635+ if (walk .nbytes < walk .total )
636+ nbytes &= ~(AES_BLOCK_SIZE - 1 );
637+
558638 kernel_neon_begin ();
559639 aes_xts_decrypt (walk .dst .virt .addr , walk .src .virt .addr ,
560- ctx -> key1 .key_dec , rounds , blocks ,
640+ ctx -> key1 .key_dec , rounds , nbytes ,
561641 ctx -> key2 .key_enc , walk .iv , first );
562642 kernel_neon_end ();
563- err = skcipher_walk_done (& walk , walk .nbytes % AES_BLOCK_SIZE );
643+ err = skcipher_walk_done (& walk , walk .nbytes - nbytes );
564644 }
565645
566- return err ;
646+ if (err || likely (!tail ))
647+ return err ;
648+
649+ dst = src = scatterwalk_ffwd (sg_src , req -> src , req -> cryptlen );
650+ if (req -> dst != req -> src )
651+ dst = scatterwalk_ffwd (sg_dst , req -> dst , req -> cryptlen );
652+
653+ skcipher_request_set_crypt (req , src , dst , AES_BLOCK_SIZE + tail ,
654+ req -> iv );
655+
656+ err = skcipher_walk_virt (& walk , & subreq , false);
657+ if (err )
658+ return err ;
659+
660+
661+ kernel_neon_begin ();
662+ aes_xts_decrypt (walk .dst .virt .addr , walk .src .virt .addr ,
663+ ctx -> key1 .key_dec , rounds , walk .nbytes ,
664+ ctx -> key2 .key_enc , walk .iv , first );
665+ kernel_neon_end ();
666+
667+ return skcipher_walk_done (& walk , 0 );
567668}
568669
569670static struct skcipher_alg aes_algs [] = { {
@@ -644,6 +745,7 @@ static struct skcipher_alg aes_algs[] = { {
644745 .min_keysize = 2 * AES_MIN_KEY_SIZE ,
645746 .max_keysize = 2 * AES_MAX_KEY_SIZE ,
646747 .ivsize = AES_BLOCK_SIZE ,
748+ .walksize = 2 * AES_BLOCK_SIZE ,
647749 .setkey = xts_set_key ,
648750 .encrypt = xts_encrypt ,
649751 .decrypt = xts_decrypt ,
0 commit comments