4
4
PluginOptions ,
5
5
createProject ,
6
6
emitProject ,
7
- getAttribute ,
8
- getAttributeArg ,
9
7
getDataModels ,
10
8
getLiteral ,
11
9
getPrismaClientImportSpec ,
@@ -17,16 +15,7 @@ import {
17
15
resolvePath ,
18
16
saveProject ,
19
17
} from '@zenstackhq/sdk' ;
20
- import {
21
- DataModel ,
22
- DataModelField ,
23
- DataSource ,
24
- EnumField ,
25
- Model ,
26
- isDataModel ,
27
- isDataSource ,
28
- isEnum ,
29
- } from '@zenstackhq/sdk/ast' ;
18
+ import { DataModel , DataSource , EnumField , Model , isDataModel , isDataSource , isEnum } from '@zenstackhq/sdk/ast' ;
30
19
import { addMissingInputObjectTypes , resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers' ;
31
20
import { promises as fs } from 'fs' ;
32
21
import { streamAllContents } from 'langium' ;
@@ -271,18 +260,14 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
271
260
overwrite : true ,
272
261
} ) ;
273
262
sf . replaceWithText ( ( writer ) => {
274
- const fields = model . fields . filter (
263
+ const scalarFields = model . fields . filter (
275
264
( field ) =>
276
265
// regular fields only
277
266
! isDataModel ( field . type . reference ?. ref ) && ! isForeignKeyField ( field )
278
267
) ;
279
268
280
269
const relations = model . fields . filter ( ( field ) => isDataModel ( field . type . reference ?. ref ) ) ;
281
270
const fkFields = model . fields . filter ( ( field ) => isForeignKeyField ( field ) ) ;
282
- // unsafe version of relations: including foreign keys and relation fields without fk
283
- const unsafeRelations = model . fields . filter (
284
- ( field ) => isForeignKeyField ( field ) || ( isDataModel ( field . type . reference ?. ref ) && ! hasForeignKey ( field ) )
285
- ) ;
286
271
287
272
writer . writeLine ( '/* eslint-disable */' ) ;
288
273
writer . writeLine ( `import { z } from 'zod';` ) ;
@@ -304,7 +289,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
304
289
305
290
// import enum schemas
306
291
const importedEnumSchemas = new Set < string > ( ) ;
307
- for ( const field of fields ) {
292
+ for ( const field of scalarFields ) {
308
293
if ( field . type . reference ?. ref && isEnum ( field . type . reference ?. ref ) ) {
309
294
const name = upperCaseFirst ( field . type . reference ?. ref . name ) ;
310
295
if ( ! importedEnumSchemas . has ( name ) ) {
@@ -315,29 +300,28 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
315
300
}
316
301
317
302
// import Decimal
318
- if ( fields . some ( ( field ) => field . type . type === 'Decimal' ) ) {
303
+ if ( scalarFields . some ( ( field ) => field . type . type === 'Decimal' ) ) {
319
304
writer . writeLine ( `import { DecimalSchema } from '../common';` ) ;
320
305
writer . writeLine ( `import { Decimal } from 'decimal.js';` ) ;
321
306
}
322
307
323
308
// base schema
324
309
writer . write ( `const baseSchema = z.object(` ) ;
325
310
writer . inlineBlock ( ( ) => {
326
- fields . forEach ( ( field ) => {
311
+ scalarFields . forEach ( ( field ) => {
327
312
writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field ) } ,` ) ;
328
313
} ) ;
329
314
} ) ;
330
315
writer . writeLine ( ');' ) ;
331
316
332
317
// relation fields
333
318
334
- let allRelationSchema : string | undefined ;
335
- let safeRelationSchema : string | undefined ;
336
- let unsafeRelationSchema : string | undefined ;
319
+ let relationSchema : string | undefined ;
320
+ let fkSchema : string | undefined ;
337
321
338
322
if ( relations . length > 0 || fkFields . length > 0 ) {
339
- allRelationSchema = 'allRelationSchema ' ;
340
- writer . write ( `const ${ allRelationSchema } = z.object(` ) ;
323
+ relationSchema = 'relationSchema ' ;
324
+ writer . write ( `const ${ relationSchema } = z.object(` ) ;
341
325
writer . inlineBlock ( ( ) => {
342
326
[ ...relations , ...fkFields ] . forEach ( ( field ) => {
343
327
writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field ) } ,` ) ;
@@ -346,23 +330,12 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
346
330
writer . writeLine ( ');' ) ;
347
331
}
348
332
349
- if ( relations . length > 0 ) {
350
- safeRelationSchema = 'safeRelationSchema ' ;
351
- writer . write ( `const ${ safeRelationSchema } = z.object(` ) ;
333
+ if ( fkFields . length > 0 ) {
334
+ fkSchema = 'fkSchema ' ;
335
+ writer . write ( `const ${ fkSchema } = z.object(` ) ;
352
336
writer . inlineBlock ( ( ) => {
353
- relations . forEach ( ( field ) => {
354
- writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field , true ) } ,` ) ;
355
- } ) ;
356
- } ) ;
357
- writer . writeLine ( ');' ) ;
358
- }
359
-
360
- if ( unsafeRelations . length > 0 ) {
361
- unsafeRelationSchema = 'unsafeRelationSchema' ;
362
- writer . write ( `const ${ unsafeRelationSchema } = z.object(` ) ;
363
- writer . inlineBlock ( ( ) => {
364
- unsafeRelations . forEach ( ( field ) => {
365
- writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field , true ) } ,` ) ;
337
+ fkFields . forEach ( ( field ) => {
338
+ writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field ) } ,` ) ;
366
339
} ) ;
367
340
} ) ;
368
341
writer . writeLine ( ');' ) ;
@@ -383,25 +356,25 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
383
356
////////////////////////////////////////////////
384
357
// 1. Model schema
385
358
////////////////////////////////////////////////
386
- let modelSchema = 'baseSchema' ;
359
+ let modelSchema = makePartial ( 'baseSchema' ) ;
387
360
388
361
// omit fields
389
- const fieldsToOmit = fields . filter ( ( field ) => hasAttribute ( field , '@omit' ) ) ;
362
+ const fieldsToOmit = scalarFields . filter ( ( field ) => hasAttribute ( field , '@omit' ) ) ;
390
363
if ( fieldsToOmit . length > 0 ) {
391
364
modelSchema = makeOmit (
392
365
modelSchema ,
393
366
fieldsToOmit . map ( ( f ) => f . name )
394
367
) ;
395
368
}
396
369
397
- if ( allRelationSchema ) {
370
+ if ( relationSchema ) {
398
371
// export schema with only scalar fields
399
372
const modelScalarSchema = `${ upperCaseFirst ( model . name ) } ScalarSchema` ;
400
373
writer . writeLine ( `export const ${ modelScalarSchema } = ${ modelSchema } ;` ) ;
401
374
modelSchema = modelScalarSchema ;
402
375
403
376
// merge relations
404
- modelSchema = makeMerge ( modelSchema , allRelationSchema ) ;
377
+ modelSchema = makeMerge ( modelSchema , makePartial ( relationSchema ) ) ;
405
378
}
406
379
407
380
// refine
@@ -413,10 +386,40 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
413
386
writer . writeLine ( `export const ${ upperCaseFirst ( model . name ) } Schema = ${ modelSchema } ;` ) ;
414
387
415
388
////////////////////////////////////////////////
416
- // 2. Create schema
389
+ // 2. Prisma create & update
390
+ ////////////////////////////////////////////////
391
+
392
+ // schema for validating prisma create input (all fields optional)
393
+ let prismaCreateSchema = makePartial ( 'baseSchema' ) ;
394
+ if ( refineFuncName ) {
395
+ prismaCreateSchema = `${ refineFuncName } (${ prismaCreateSchema } )` ;
396
+ }
397
+ writer . writeLine ( `export const ${ upperCaseFirst ( model . name ) } PrismaCreateSchema = ${ prismaCreateSchema } ;` ) ;
398
+
399
+ // schema for validating prisma update input (all fields optional)
400
+ // note numeric fields can be simple update or atomic operations
401
+ let prismaUpdateSchema = `z.object({
402
+ ${ scalarFields
403
+ . map ( ( field ) => {
404
+ let fieldSchema = makeFieldSchema ( field ) ;
405
+ if ( field . type . type === 'Int' || field . type . type === 'Float' ) {
406
+ fieldSchema = `z.union([${ fieldSchema } , z.record(z.unknown())])` ;
407
+ }
408
+ return `\t${ field . name } : ${ fieldSchema } ` ;
409
+ } )
410
+ . join ( ',\n' ) }
411
+ })` ;
412
+ prismaUpdateSchema = makePartial ( prismaUpdateSchema ) ;
413
+ if ( refineFuncName ) {
414
+ prismaUpdateSchema = `${ refineFuncName } (${ prismaUpdateSchema } )` ;
415
+ }
416
+ writer . writeLine ( `export const ${ upperCaseFirst ( model . name ) } PrismaUpdateSchema = ${ prismaUpdateSchema } ;` ) ;
417
+
418
+ ////////////////////////////////////////////////
419
+ // 3. Create schema
417
420
////////////////////////////////////////////////
418
421
let createSchema = 'baseSchema' ;
419
- const fieldsWithDefault = fields . filter (
422
+ const fieldsWithDefault = scalarFields . filter (
420
423
( field ) => hasAttribute ( field , '@default' ) || hasAttribute ( field , '@updatedAt' ) || field . type . array
421
424
) ;
422
425
if ( fieldsWithDefault . length > 0 ) {
@@ -426,30 +429,13 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
426
429
) ;
427
430
}
428
431
429
- if ( safeRelationSchema || unsafeRelationSchema ) {
432
+ if ( fkSchema ) {
430
433
// export schema with only scalar fields
431
434
const createScalarSchema = `${ upperCaseFirst ( model . name ) } CreateScalarSchema` ;
432
435
writer . writeLine ( `export const ${ createScalarSchema } = ${ createSchema } ;` ) ;
433
- createSchema = createScalarSchema ;
434
-
435
- if ( safeRelationSchema && unsafeRelationSchema ) {
436
- // build a union of with relation object fields and with fk fields (mutually exclusive)
437
-
438
- // TODO: we make all relation fields partial for now because in case of
439
- // nested create, not all relation/fk fields are inside payload, need a
440
- // better solution
441
- createSchema = makeUnion (
442
- makeMerge ( createSchema , makePartial ( safeRelationSchema ) ) ,
443
- makeMerge ( createSchema , makePartial ( unsafeRelationSchema ) )
444
- ) ;
445
- } else if ( safeRelationSchema ) {
446
- // just relation
447
-
448
- // TODO: we make all relation fields partial for now because in case of
449
- // nested create, not all relation/fk fields are inside payload, need a
450
- // better solution
451
- createSchema = makeMerge ( createSchema , makePartial ( safeRelationSchema ) ) ;
452
- }
436
+
437
+ // merge fk fields
438
+ createSchema = makeMerge ( createScalarSchema , fkSchema ) ;
453
439
}
454
440
455
441
if ( refineFuncName ) {
@@ -465,22 +451,14 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
465
451
////////////////////////////////////////////////
466
452
let updateSchema = makePartial ( 'baseSchema' ) ;
467
453
468
- if ( safeRelationSchema || unsafeRelationSchema ) {
454
+ if ( fkSchema ) {
469
455
// export schema with only scalar fields
470
456
const updateScalarSchema = `${ upperCaseFirst ( model . name ) } UpdateScalarSchema` ;
471
457
writer . writeLine ( `export const ${ updateScalarSchema } = ${ updateSchema } ;` ) ;
472
458
updateSchema = updateScalarSchema ;
473
459
474
- if ( safeRelationSchema && unsafeRelationSchema ) {
475
- // build a union of with relation object fields and with fk fields (mutually exclusive)
476
- updateSchema = makeUnion (
477
- makeMerge ( updateSchema , makePartial ( safeRelationSchema ) ) ,
478
- makeMerge ( updateSchema , makePartial ( unsafeRelationSchema ) )
479
- ) ;
480
- } else if ( safeRelationSchema ) {
481
- // just relation
482
- updateSchema = makeMerge ( updateSchema , makePartial ( safeRelationSchema ) ) ;
483
- }
460
+ // merge fk fields
461
+ updateSchema = makeMerge ( updateSchema , makePartial ( fkSchema ) ) ;
484
462
}
485
463
486
464
if ( refineFuncName ) {
@@ -514,15 +492,3 @@ function makeOmit(schema: string, fields: string[]) {
514
492
function makeMerge ( schema1 : string , schema2 : string ) : string {
515
493
return `${ schema1 } .merge(${ schema2 } )` ;
516
494
}
517
-
518
- function makeUnion ( ...schemas : string [ ] ) : string {
519
- return `z.union([${ schemas . join ( ', ' ) } ])` ;
520
- }
521
-
522
- function hasForeignKey ( field : DataModelField ) {
523
- const relAttr = getAttribute ( field , '@relation' ) ;
524
- if ( ! relAttr ) {
525
- return false ;
526
- }
527
- return ! ! getAttributeArg ( relAttr , 'fields' ) ;
528
- }
0 commit comments