@@ -467,7 +467,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
467
467
// Validates the given create payload against Zod schema if any
468
468
private validateCreateInputSchema ( model : string , data : any ) {
469
469
const schema = this . utils . getZodSchema ( model , 'create' ) ;
470
- if ( schema ) {
470
+ if ( schema && data ) {
471
471
const parseResult = schema . safeParse ( data ) ;
472
472
if ( ! parseResult . success ) {
473
473
throw this . utils . deniedByPolicy (
@@ -496,26 +496,29 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
496
496
497
497
args = this . utils . clone ( args ) ;
498
498
499
- // do static input validation and check if post-create checks are needed
499
+ // go through create items, statically check input to determine if post-create
500
+ // check is needed, and also validate zod schema
500
501
let needPostCreateCheck = false ;
501
502
for ( const item of enumerate ( args . data ) ) {
503
+ const validationResult = this . validateCreateInputSchema ( this . model , item ) ;
504
+ if ( validationResult !== item ) {
505
+ this . utils . replace ( item , validationResult ) ;
506
+ }
507
+
502
508
const inputCheck = this . utils . checkInputGuard ( this . model , item , 'create' ) ;
503
509
if ( inputCheck === false ) {
510
+ // unconditionally deny
504
511
throw this . utils . deniedByPolicy (
505
512
this . model ,
506
513
'create' ,
507
514
undefined ,
508
515
CrudFailureReason . ACCESS_POLICY_VIOLATION
509
516
) ;
510
517
} else if ( inputCheck === true ) {
511
- const r = this . validateCreateInputSchema ( this . model , item ) ;
512
- if ( r !== item ) {
513
- this . utils . replace ( item , r ) ;
514
- }
518
+ // unconditionally allow
515
519
} else if ( inputCheck === undefined ) {
516
520
// static policy check is not possible, need to do post-create check
517
521
needPostCreateCheck = true ;
518
- break ;
519
522
}
520
523
}
521
524
@@ -786,7 +789,13 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
786
789
787
790
// check if the update actually writes to this model
788
791
let thisModelUpdate = false ;
789
- const updatePayload : any = ( args as any ) . data ?? args ;
792
+ const updatePayload = ( args as any ) . data ?? args ;
793
+
794
+ const validatedPayload = this . validateUpdateInputSchema ( model , updatePayload ) ;
795
+ if ( validatedPayload !== updatePayload ) {
796
+ this . utils . replace ( updatePayload , validatedPayload ) ;
797
+ }
798
+
790
799
if ( updatePayload ) {
791
800
for ( const key of Object . keys ( updatePayload ) ) {
792
801
const field = resolveField ( this . modelMeta , model , key ) ;
@@ -857,6 +866,8 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
857
866
) ;
858
867
}
859
868
869
+ args . data = this . validateUpdateInputSchema ( model , args . data ) ;
870
+
860
871
const updateGuard = this . utils . getAuthGuard ( db , model , 'update' ) ;
861
872
if ( this . utils . isTrue ( updateGuard ) || this . utils . isFalse ( updateGuard ) ) {
862
873
// injects simple auth guard into where clause
@@ -917,7 +928,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
917
928
await _registerPostUpdateCheck ( model , uniqueFilter ) ;
918
929
919
930
// convert upsert to update
920
- context . parent . update = { where : args . where , data : args . update } ;
931
+ context . parent . update = {
932
+ where : args . where ,
933
+ data : this . validateUpdateInputSchema ( model , args . update ) ,
934
+ } ;
921
935
delete context . parent . upsert ;
922
936
923
937
// continue visiting the new payload
@@ -1016,6 +1030,37 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
1016
1030
return { result, postWriteChecks } ;
1017
1031
}
1018
1032
1033
+ // Validates the given update payload against Zod schema if any
1034
+ private validateUpdateInputSchema ( model : string , data : any ) {
1035
+ const schema = this . utils . getZodSchema ( model , 'update' ) ;
1036
+ if ( schema && data ) {
1037
+ // update payload can contain non-literal fields, like:
1038
+ // { x: { increment: 1 } }
1039
+ // we should only validate literal fields
1040
+
1041
+ const literalData = Object . entries ( data ) . reduce < any > (
1042
+ ( acc , [ k , v ] ) => ( { ...acc , ...( typeof v !== 'object' ? { [ k ] : v } : { } ) } ) ,
1043
+ { }
1044
+ ) ;
1045
+
1046
+ const parseResult = schema . safeParse ( literalData ) ;
1047
+ if ( ! parseResult . success ) {
1048
+ throw this . utils . deniedByPolicy (
1049
+ model ,
1050
+ 'update' ,
1051
+ `input failed validation: ${ fromZodError ( parseResult . error ) } ` ,
1052
+ CrudFailureReason . DATA_VALIDATION_VIOLATION ,
1053
+ parseResult . error
1054
+ ) ;
1055
+ }
1056
+
1057
+ // schema may have transformed field values, use it to overwrite the original data
1058
+ return { ...data , ...parseResult . data } ;
1059
+ } else {
1060
+ return data ;
1061
+ }
1062
+ }
1063
+
1019
1064
private isUnsafeMutate ( model : string , args : any ) {
1020
1065
if ( ! args ) {
1021
1066
return false ;
@@ -1046,6 +1091,8 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
1046
1091
args = this . utils . clone ( args ) ;
1047
1092
this . utils . injectAuthGuardAsWhere ( this . prisma , args , this . model , 'update' ) ;
1048
1093
1094
+ args . data = this . validateUpdateInputSchema ( this . model , args . data ) ;
1095
+
1049
1096
if ( this . utils . hasAuthGuard ( this . model , 'postUpdate' ) || this . utils . getZodSchema ( this . model ) ) {
1050
1097
// use a transaction to do post-update checks
1051
1098
const postWriteChecks : PostWriteCheckRecord [ ] = [ ] ;
0 commit comments