Skip to content

Commit 9d4a8ed

Browse files
authored
fix: clean up zod generation (#883)
1 parent aa705a4 commit 9d4a8ed

File tree

17 files changed

+114
-117
lines changed

17 files changed

+114
-117
lines changed

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "zenstack-monorepo",
3-
"version": "1.4.0",
3+
"version": "1.4.1",
44
"description": "",
55
"scripts": {
66
"build": "pnpm -r build",

packages/language/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@zenstackhq/language",
3-
"version": "1.4.0",
3+
"version": "1.4.1",
44
"displayName": "ZenStack modeling language compiler",
55
"description": "ZenStack modeling language compiler",
66
"homepage": "https://zenstack.dev",

packages/plugins/openapi/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/openapi",
33
"displayName": "ZenStack Plugin and Runtime for OpenAPI",
4-
"version": "1.4.0",
4+
"version": "1.4.1",
55
"description": "ZenStack plugin and runtime supporting OpenAPI",
66
"main": "index.js",
77
"repository": {

packages/plugins/swr/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/swr",
33
"displayName": "ZenStack plugin for generating SWR hooks",
4-
"version": "1.4.0",
4+
"version": "1.4.1",
55
"description": "ZenStack plugin for generating SWR hooks",
66
"main": "index.js",
77
"repository": {

packages/plugins/tanstack-query/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/tanstack-query",
33
"displayName": "ZenStack plugin for generating tanstack-query hooks",
4-
"version": "1.4.0",
4+
"version": "1.4.1",
55
"description": "ZenStack plugin for generating tanstack-query hooks",
66
"main": "index.js",
77
"exports": {

packages/plugins/trpc/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/trpc",
33
"displayName": "ZenStack plugin for tRPC",
4-
"version": "1.4.0",
4+
"version": "1.4.1",
55
"description": "ZenStack plugin for tRPC",
66
"main": "index.js",
77
"repository": {

packages/runtime/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/runtime",
33
"displayName": "ZenStack Runtime Library",
4-
"version": "1.4.0",
4+
"version": "1.4.1",
55
"description": "Runtime of ZenStack for both client-side and server-side environments.",
66
"repository": {
77
"type": "git",

packages/runtime/src/enhancements/policy/policy-utils.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1051,7 +1051,7 @@ export class PolicyUtil {
10511051
if (!this.hasFieldValidation(model)) {
10521052
return undefined;
10531053
}
1054-
const schemaKey = `${upperCaseFirst(model)}${kind ? upperCaseFirst(kind) : ''}Schema`;
1054+
const schemaKey = `${upperCaseFirst(model)}${kind ? 'Prisma' + upperCaseFirst(kind) : ''}Schema`;
10551055
return this.zodSchemas?.models?.[schemaKey];
10561056
}
10571057

packages/schema/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"publisher": "zenstack",
44
"displayName": "ZenStack Language Tools",
55
"description": "Build scalable web apps with minimum code by defining authorization and validation rules inside the data schema that closer to the database",
6-
"version": "1.4.0",
6+
"version": "1.4.1",
77
"author": {
88
"name": "ZenStack Team"
99
},

packages/schema/src/plugins/zod/generator.ts

Lines changed: 57 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ import {
44
PluginOptions,
55
createProject,
66
emitProject,
7-
getAttribute,
8-
getAttributeArg,
97
getDataModels,
108
getLiteral,
119
getPrismaClientImportSpec,
@@ -17,16 +15,7 @@ import {
1715
resolvePath,
1816
saveProject,
1917
} from '@zenstackhq/sdk';
20-
import {
21-
DataModel,
22-
DataModelField,
23-
DataSource,
24-
EnumField,
25-
Model,
26-
isDataModel,
27-
isDataSource,
28-
isEnum,
29-
} from '@zenstackhq/sdk/ast';
18+
import { DataModel, DataSource, EnumField, Model, isDataModel, isDataSource, isEnum } from '@zenstackhq/sdk/ast';
3019
import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers';
3120
import { promises as fs } from 'fs';
3221
import { streamAllContents } from 'langium';
@@ -271,18 +260,14 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
271260
overwrite: true,
272261
});
273262
sf.replaceWithText((writer) => {
274-
const fields = model.fields.filter(
263+
const scalarFields = model.fields.filter(
275264
(field) =>
276265
// regular fields only
277266
!isDataModel(field.type.reference?.ref) && !isForeignKeyField(field)
278267
);
279268

280269
const relations = model.fields.filter((field) => isDataModel(field.type.reference?.ref));
281270
const fkFields = model.fields.filter((field) => isForeignKeyField(field));
282-
// unsafe version of relations: including foreign keys and relation fields without fk
283-
const unsafeRelations = model.fields.filter(
284-
(field) => isForeignKeyField(field) || (isDataModel(field.type.reference?.ref) && !hasForeignKey(field))
285-
);
286271

287272
writer.writeLine('/* eslint-disable */');
288273
writer.writeLine(`import { z } from 'zod';`);
@@ -304,7 +289,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
304289

305290
// import enum schemas
306291
const importedEnumSchemas = new Set<string>();
307-
for (const field of fields) {
292+
for (const field of scalarFields) {
308293
if (field.type.reference?.ref && isEnum(field.type.reference?.ref)) {
309294
const name = upperCaseFirst(field.type.reference?.ref.name);
310295
if (!importedEnumSchemas.has(name)) {
@@ -315,29 +300,28 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
315300
}
316301

317302
// import Decimal
318-
if (fields.some((field) => field.type.type === 'Decimal')) {
303+
if (scalarFields.some((field) => field.type.type === 'Decimal')) {
319304
writer.writeLine(`import { DecimalSchema } from '../common';`);
320305
writer.writeLine(`import { Decimal } from 'decimal.js';`);
321306
}
322307

323308
// base schema
324309
writer.write(`const baseSchema = z.object(`);
325310
writer.inlineBlock(() => {
326-
fields.forEach((field) => {
311+
scalarFields.forEach((field) => {
327312
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
328313
});
329314
});
330315
writer.writeLine(');');
331316

332317
// relation fields
333318

334-
let allRelationSchema: string | undefined;
335-
let safeRelationSchema: string | undefined;
336-
let unsafeRelationSchema: string | undefined;
319+
let relationSchema: string | undefined;
320+
let fkSchema: string | undefined;
337321

338322
if (relations.length > 0 || fkFields.length > 0) {
339-
allRelationSchema = 'allRelationSchema';
340-
writer.write(`const ${allRelationSchema} = z.object(`);
323+
relationSchema = 'relationSchema';
324+
writer.write(`const ${relationSchema} = z.object(`);
341325
writer.inlineBlock(() => {
342326
[...relations, ...fkFields].forEach((field) => {
343327
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
@@ -346,23 +330,12 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
346330
writer.writeLine(');');
347331
}
348332

349-
if (relations.length > 0) {
350-
safeRelationSchema = 'safeRelationSchema';
351-
writer.write(`const ${safeRelationSchema} = z.object(`);
333+
if (fkFields.length > 0) {
334+
fkSchema = 'fkSchema';
335+
writer.write(`const ${fkSchema} = z.object(`);
352336
writer.inlineBlock(() => {
353-
relations.forEach((field) => {
354-
writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`);
355-
});
356-
});
357-
writer.writeLine(');');
358-
}
359-
360-
if (unsafeRelations.length > 0) {
361-
unsafeRelationSchema = 'unsafeRelationSchema';
362-
writer.write(`const ${unsafeRelationSchema} = z.object(`);
363-
writer.inlineBlock(() => {
364-
unsafeRelations.forEach((field) => {
365-
writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`);
337+
fkFields.forEach((field) => {
338+
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
366339
});
367340
});
368341
writer.writeLine(');');
@@ -383,25 +356,25 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
383356
////////////////////////////////////////////////
384357
// 1. Model schema
385358
////////////////////////////////////////////////
386-
let modelSchema = 'baseSchema';
359+
let modelSchema = makePartial('baseSchema');
387360

388361
// omit fields
389-
const fieldsToOmit = fields.filter((field) => hasAttribute(field, '@omit'));
362+
const fieldsToOmit = scalarFields.filter((field) => hasAttribute(field, '@omit'));
390363
if (fieldsToOmit.length > 0) {
391364
modelSchema = makeOmit(
392365
modelSchema,
393366
fieldsToOmit.map((f) => f.name)
394367
);
395368
}
396369

397-
if (allRelationSchema) {
370+
if (relationSchema) {
398371
// export schema with only scalar fields
399372
const modelScalarSchema = `${upperCaseFirst(model.name)}ScalarSchema`;
400373
writer.writeLine(`export const ${modelScalarSchema} = ${modelSchema};`);
401374
modelSchema = modelScalarSchema;
402375

403376
// merge relations
404-
modelSchema = makeMerge(modelSchema, allRelationSchema);
377+
modelSchema = makeMerge(modelSchema, makePartial(relationSchema));
405378
}
406379

407380
// refine
@@ -413,10 +386,40 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
413386
writer.writeLine(`export const ${upperCaseFirst(model.name)}Schema = ${modelSchema};`);
414387

415388
////////////////////////////////////////////////
416-
// 2. Create schema
389+
// 2. Prisma create & update
390+
////////////////////////////////////////////////
391+
392+
// schema for validating prisma create input (all fields optional)
393+
let prismaCreateSchema = makePartial('baseSchema');
394+
if (refineFuncName) {
395+
prismaCreateSchema = `${refineFuncName}(${prismaCreateSchema})`;
396+
}
397+
writer.writeLine(`export const ${upperCaseFirst(model.name)}PrismaCreateSchema = ${prismaCreateSchema};`);
398+
399+
// schema for validating prisma update input (all fields optional)
400+
// note numeric fields can be simple update or atomic operations
401+
let prismaUpdateSchema = `z.object({
402+
${scalarFields
403+
.map((field) => {
404+
let fieldSchema = makeFieldSchema(field);
405+
if (field.type.type === 'Int' || field.type.type === 'Float') {
406+
fieldSchema = `z.union([${fieldSchema}, z.record(z.unknown())])`;
407+
}
408+
return `\t${field.name}: ${fieldSchema}`;
409+
})
410+
.join(',\n')}
411+
})`;
412+
prismaUpdateSchema = makePartial(prismaUpdateSchema);
413+
if (refineFuncName) {
414+
prismaUpdateSchema = `${refineFuncName}(${prismaUpdateSchema})`;
415+
}
416+
writer.writeLine(`export const ${upperCaseFirst(model.name)}PrismaUpdateSchema = ${prismaUpdateSchema};`);
417+
418+
////////////////////////////////////////////////
419+
// 3. Create schema
417420
////////////////////////////////////////////////
418421
let createSchema = 'baseSchema';
419-
const fieldsWithDefault = fields.filter(
422+
const fieldsWithDefault = scalarFields.filter(
420423
(field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array
421424
);
422425
if (fieldsWithDefault.length > 0) {
@@ -426,30 +429,13 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
426429
);
427430
}
428431

429-
if (safeRelationSchema || unsafeRelationSchema) {
432+
if (fkSchema) {
430433
// export schema with only scalar fields
431434
const createScalarSchema = `${upperCaseFirst(model.name)}CreateScalarSchema`;
432435
writer.writeLine(`export const ${createScalarSchema} = ${createSchema};`);
433-
createSchema = createScalarSchema;
434-
435-
if (safeRelationSchema && unsafeRelationSchema) {
436-
// build a union of with relation object fields and with fk fields (mutually exclusive)
437-
438-
// TODO: we make all relation fields partial for now because in case of
439-
// nested create, not all relation/fk fields are inside payload, need a
440-
// better solution
441-
createSchema = makeUnion(
442-
makeMerge(createSchema, makePartial(safeRelationSchema)),
443-
makeMerge(createSchema, makePartial(unsafeRelationSchema))
444-
);
445-
} else if (safeRelationSchema) {
446-
// just relation
447-
448-
// TODO: we make all relation fields partial for now because in case of
449-
// nested create, not all relation/fk fields are inside payload, need a
450-
// better solution
451-
createSchema = makeMerge(createSchema, makePartial(safeRelationSchema));
452-
}
436+
437+
// merge fk fields
438+
createSchema = makeMerge(createScalarSchema, fkSchema);
453439
}
454440

455441
if (refineFuncName) {
@@ -465,22 +451,14 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
465451
////////////////////////////////////////////////
466452
let updateSchema = makePartial('baseSchema');
467453

468-
if (safeRelationSchema || unsafeRelationSchema) {
454+
if (fkSchema) {
469455
// export schema with only scalar fields
470456
const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`;
471457
writer.writeLine(`export const ${updateScalarSchema} = ${updateSchema};`);
472458
updateSchema = updateScalarSchema;
473459

474-
if (safeRelationSchema && unsafeRelationSchema) {
475-
// build a union of with relation object fields and with fk fields (mutually exclusive)
476-
updateSchema = makeUnion(
477-
makeMerge(updateSchema, makePartial(safeRelationSchema)),
478-
makeMerge(updateSchema, makePartial(unsafeRelationSchema))
479-
);
480-
} else if (safeRelationSchema) {
481-
// just relation
482-
updateSchema = makeMerge(updateSchema, makePartial(safeRelationSchema));
483-
}
460+
// merge fk fields
461+
updateSchema = makeMerge(updateSchema, makePartial(fkSchema));
484462
}
485463

486464
if (refineFuncName) {
@@ -514,15 +492,3 @@ function makeOmit(schema: string, fields: string[]) {
514492
function makeMerge(schema1: string, schema2: string): string {
515493
return `${schema1}.merge(${schema2})`;
516494
}
517-
518-
function makeUnion(...schemas: string[]): string {
519-
return `z.union([${schemas.join(', ')}])`;
520-
}
521-
522-
function hasForeignKey(field: DataModelField) {
523-
const relAttr = getAttribute(field, '@relation');
524-
if (!relAttr) {
525-
return false;
526-
}
527-
return !!getAttributeArg(relAttr, 'fields');
528-
}

0 commit comments

Comments
 (0)