Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions packages/plugins/react/src/generator/react-query.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { DMMF } from '@prisma/generator-helper';
import { PluginOptions, createProject, getDataModels, saveProject } from '@zenstackhq/sdk';
import { PluginOptions, createProject, getDataModels, getPrismaClientImportSpec, saveProject } from '@zenstackhq/sdk';
import { DataModel, Model } from '@zenstackhq/sdk/ast';
import { requireOption, resolvePath } from '@zenstackhq/sdk/utils';
import { paramCase } from 'change-case';
Expand Down Expand Up @@ -147,14 +147,18 @@ function generateMutationHook(

function generateModelHooks(project: Project, outDir: string, model: DataModel, mapping: DMMF.ModelMapping) {
const fileName = paramCase(model.name);
const sf = project.createSourceFile(path.join(outDir, `${fileName}.ts`), undefined, { overwrite: true });

const hooksFile = path.resolve(outDir, `${fileName}.ts`);
const sf = project.createSourceFile(hooksFile, undefined, { overwrite: true });

sf.addStatements('/* eslint-disable */');

const prismaImport = getPrismaClientImportSpec(model.$container, path.dirname(hooksFile));

sf.addImportDeclaration({
namedImports: ['Prisma', model.name],
isTypeOnly: true,
moduleSpecifier: '@prisma/client',
moduleSpecifier: prismaImport,
});
sf.addStatements([
`import { useContext } from 'react';`,
Expand Down
4 changes: 3 additions & 1 deletion packages/plugins/react/src/generator/swr.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
PluginOptions,
createProject,
getDataModels,
getPrismaClientImportSpec,
requireOption,
resolvePath,
saveProject,
Expand Down Expand Up @@ -56,10 +57,11 @@ function generateModelHooks(project: Project, outDir: string, model: DataModel,

sf.addStatements('/* eslint-disable */');

const prismaImport = getPrismaClientImportSpec(model.$container, outDir);
sf.addImportDeclaration({
namedImports: ['Prisma', model.name],
isTypeOnly: true,
moduleSpecifier: '@prisma/client',
moduleSpecifier: prismaImport,
});
sf.addStatements([
`import { useContext } from 'react';`,
Expand Down
4 changes: 3 additions & 1 deletion packages/plugins/swr/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
PluginOptions,
createProject,
getDataModels,
getPrismaClientImportSpec,
requireOption,
resolvePath,
saveProject,
Expand Down Expand Up @@ -59,10 +60,11 @@ function generateModelHooks(project: Project, outDir: string, model: DataModel,

sf.addStatements('/* eslint-disable */');

const prismaImport = getPrismaClientImportSpec(model.$container, outDir);
sf.addImportDeclaration({
namedImports: ['Prisma', model.name],
isTypeOnly: true,
moduleSpecifier: '@prisma/client',
moduleSpecifier: prismaImport,
});
sf.addStatements([
`import { useContext } from 'react';`,
Expand Down
4 changes: 3 additions & 1 deletion packages/plugins/tanstack-query/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
PluginOptions,
createProject,
getDataModels,
getPrismaClientImportSpec,
requireOption,
resolvePath,
saveProject,
Expand Down Expand Up @@ -218,10 +219,11 @@ function generateModelHooks(

sf.addStatements('/* eslint-disable */');

const prismaImport = getPrismaClientImportSpec(model.$container, outDir);
sf.addImportDeclaration({
namedImports: ['Prisma', model.name],
isTypeOnly: true,
moduleSpecifier: '@prisma/client',
moduleSpecifier: prismaImport,
});
sf.addStatements(makeBaseImports(target));

Expand Down
20 changes: 13 additions & 7 deletions packages/plugins/trpc/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
PluginError,
PluginOptions,
RUNTIME_PACKAGE,
getPrismaClientImportSpec,
requireOption,
resolvePath,
saveProject,
Expand Down Expand Up @@ -53,7 +54,7 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.
const hiddenModels: string[] = [];
resolveModelsComments(models, hiddenModels);

createAppRouter(outDir, modelOperations, hiddenModels, generateModelActions, generateClientHelpers);
createAppRouter(outDir, modelOperations, hiddenModels, generateModelActions, generateClientHelpers, model);
createHelper(outDir);

await saveProject(project);
Expand All @@ -64,22 +65,25 @@ function createAppRouter(
modelOperations: DMMF.ModelMapping[],
hiddenModels: string[],
generateModelActions: string[] | undefined,
generateClientHelpers: string[] | undefined
generateClientHelpers: string[] | undefined,
zmodel: Model
) {
const appRouter = project.createSourceFile(path.resolve(outDir, 'routers', `index.ts`), undefined, {
const indexFile = path.resolve(outDir, 'routers', `index.ts`);
const appRouter = project.createSourceFile(indexFile, undefined, {
overwrite: true,
});

appRouter.addStatements('/* eslint-disable */');

const prismaImport = getPrismaClientImportSpec(zmodel, path.dirname(indexFile));
appRouter.addImportDeclarations([
{
namedImports: ['AnyRootConfig'],
moduleSpecifier: '@trpc/server',
},
{
namedImports: ['PrismaClient'],
moduleSpecifier: '@prisma/client',
moduleSpecifier: prismaImport,
},
{
namedImports: ['createRouterFactory', 'AnyRouter'],
Expand Down Expand Up @@ -133,7 +137,8 @@ function createAppRouter(
operations,
outDir,
generateModelActions,
generateClientHelpers
generateClientHelpers,
zmodel
);

appRouter.addImportDeclaration({
Expand Down Expand Up @@ -201,7 +206,8 @@ function generateModelCreateRouter(
operations: Record<string, string | undefined | null>,
outputDir: string,
generateModelActions: string[] | undefined,
generateClientHelpers: string[] | undefined
generateClientHelpers: string[] | undefined,
zmodel: Model
) {
const modelRouter = project.createSourceFile(path.resolve(outputDir, 'routers', `${model}.router.ts`), undefined, {
overwrite: true,
Expand All @@ -219,7 +225,7 @@ function generateModelCreateRouter(
generateRouterSchemaImports(modelRouter, model);
generateHelperImport(modelRouter);
if (generateClientHelpers) {
generateRouterTypingImports(modelRouter);
generateRouterTypingImports(modelRouter, zmodel);
}

const createRouterFunc = modelRouter.addFunction({
Expand Down
9 changes: 6 additions & 3 deletions packages/plugins/trpc/src/helpers.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { DMMF } from '@prisma/generator-helper';
import { PluginError } from '@zenstackhq/sdk';
import { PluginError, getPrismaClientImportSpec } from '@zenstackhq/sdk';
import { CodeBlockWriter, SourceFile } from 'ts-morph';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '.';
import { uncapitalizeFirstLetter } from './utils/uncapitalizeFirstLetter';
import { Model } from '@zenstackhq/sdk/ast';

export function generateProcedure(
writer: CodeBlockWriter,
Expand Down Expand Up @@ -226,9 +227,11 @@ export function generateRouterTyping(writer: CodeBlockWriter, opType: string, mo
});
}

export function generateRouterTypingImports(sourceFile: SourceFile) {
export function generateRouterTypingImports(sourceFile: SourceFile, model: Model) {
const importingDir = sourceFile.getDirectoryPath();
const prismaImport = getPrismaClientImportSpec(model, importingDir);
sourceFile.addStatements([
`import type { Prisma } from '@prisma/client';`,
`import type { Prisma } from '${prismaImport}';`,
`import type { UseTRPCMutationOptions, UseTRPCMutationResult, UseTRPCQueryOptions, UseTRPCQueryResult, UseTRPCInfiniteQueryOptions, UseTRPCInfiniteQueryResult } from '@trpc/react-query/shared';`,
`import type { TRPCClientErrorLike } from '@trpc/client';`,
`import type { AnyRouter } from '@trpc/server';`,
Expand Down
25 changes: 18 additions & 7 deletions packages/plugins/trpc/src/zod/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.
const outputObjectTypes = prismaClientDmmf.schema.outputObjectTypes.prisma;
const models: DMMF.Model[] = prismaClientDmmf.datamodel.models;

await generateEnumSchemas(prismaClientDmmf.schema.enumTypes.prisma, prismaClientDmmf.schema.enumTypes.model ?? []);
await generateEnumSchemas(
prismaClientDmmf.schema.enumTypes.prisma,
prismaClientDmmf.schema.enumTypes.model ?? [],
model
);

const dataSource = model.declarations.find((d): d is DataSource => isDataSource(d));

Expand All @@ -43,8 +47,8 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.

const aggregateOperationSupport = resolveAggregateOperationSupport(inputObjectTypes);

await generateObjectSchemas(inputObjectTypes, output);
await generateModelSchemas(models, modelOperations, aggregateOperationSupport);
await generateObjectSchemas(inputObjectTypes, output, model);
await generateModelSchemas(models, modelOperations, aggregateOperationSupport, model);
}

async function handleGeneratorOutputValue(output: string) {
Expand All @@ -56,22 +60,27 @@ async function handleGeneratorOutputValue(output: string) {
Transformer.setOutputPath(output);
}

async function generateEnumSchemas(prismaSchemaEnum: DMMF.SchemaEnum[], modelSchemaEnum: DMMF.SchemaEnum[]) {
async function generateEnumSchemas(
prismaSchemaEnum: DMMF.SchemaEnum[],
modelSchemaEnum: DMMF.SchemaEnum[],
zmodel: Model
) {
const enumTypes = [...prismaSchemaEnum, ...modelSchemaEnum];
const enumNames = enumTypes.map((enumItem) => enumItem.name);
Transformer.enumNames = enumNames ?? [];
const transformer = new Transformer({
enumTypes,
zmodel,
});
await transformer.generateEnumSchemas();
}

async function generateObjectSchemas(inputObjectTypes: DMMF.InputType[], output: string) {
async function generateObjectSchemas(inputObjectTypes: DMMF.InputType[], output: string, zmodel: Model) {
const moduleNames: string[] = [];
for (let i = 0; i < inputObjectTypes.length; i += 1) {
const fields = inputObjectTypes[i]?.fields;
const name = inputObjectTypes[i]?.name;
const transformer = new Transformer({ name, fields });
const transformer = new Transformer({ name, fields, zmodel });
const moduleName = await transformer.generateObjectSchema();
moduleNames.push(moduleName);
}
Expand All @@ -84,12 +93,14 @@ async function generateObjectSchemas(inputObjectTypes: DMMF.InputType[], output:
async function generateModelSchemas(
models: DMMF.Model[],
modelOperations: DMMF.ModelMapping[],
aggregateOperationSupport: AggregateOperationSupport
aggregateOperationSupport: AggregateOperationSupport,
zmodel: Model
) {
const transformer = new Transformer({
models,
modelOperations,
aggregateOperationSupport,
zmodel,
});
await transformer.generateModelSchemas();
}
35 changes: 6 additions & 29 deletions packages/plugins/trpc/src/zod/transformer.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
import type { DMMF as PrismaDMMF } from '@prisma/generator-helper';
import { AUXILIARY_FIELDS } from '@zenstackhq/sdk';
import { AUXILIARY_FIELDS, getPrismaClientImportSpec } from '@zenstackhq/sdk';
import { Model } from '@zenstackhq/sdk/ast';
import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers';
import indentString from '@zenstackhq/sdk/utils';
import path from 'path';
Expand All @@ -21,8 +22,7 @@ export default class Transformer {
static provider: string;
private static outputPath = './generated';
private hasJson = false;
private static prismaClientOutputPath = '@prisma/client';
private static isCustomPrismaClientOutputPath = false;
private zmodel: Model;

constructor(params: TransformerParams) {
this.name = params.name ?? '';
Expand All @@ -31,6 +31,7 @@ export default class Transformer {
this.modelOperations = params.modelOperations ?? [];
this.aggregateOperationSupport = params.aggregateOperationSupport ?? {};
this.enumTypes = params.enumTypes ?? [];
this.zmodel = params.zmodel;
}

static setOutputPath(outPath: string) {
Expand All @@ -41,11 +42,6 @@ export default class Transformer {
return this.outputPath;
}

static setPrismaClientOutputPath(prismaClientCustomPath: string) {
this.prismaClientOutputPath = prismaClientCustomPath;
this.isCustomPrismaClientOutputPath = prismaClientCustomPath !== '@prisma/client';
}

async generateEnumSchemas() {
for (const enumType of this.enumTypes) {
const { name, values } = enumType;
Expand Down Expand Up @@ -268,27 +264,8 @@ export default class Transformer {
}

generateImportPrismaStatement() {
let prismaClientImportPath: string;
if (Transformer.isCustomPrismaClientOutputPath) {
/**
* If a custom location was designated for the prisma client, we need to figure out the
* relative path from {outputPath}/schemas/objects to {prismaClientCustomPath}
*/
const fromPath = path.join(Transformer.outputPath, 'schemas', 'objects');
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const toPath = Transformer.prismaClientOutputPath!;
const relativePathFromOutputToPrismaClient = path
.relative(fromPath, toPath)
.split(path.sep)
.join(path.posix.sep);
prismaClientImportPath = relativePathFromOutputToPrismaClient;
} else {
/**
* If the default output path for prisma client (@prisma/client) is being used, we can import from it directly
* without having to resolve a relative path
*/
prismaClientImportPath = Transformer.prismaClientOutputPath;
}
const importingFrom = path.resolve(Transformer.outputPath, 'schemas', 'objects');
const prismaClientImportPath = getPrismaClientImportSpec(this.zmodel, importingFrom);
return `import type { Prisma } from '${prismaClientImportPath}';\n\n`;
}

Expand Down
2 changes: 2 additions & 0 deletions packages/plugins/trpc/src/zod/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { DMMF as PrismaDMMF } from '@prisma/generator-helper';
import { Model } from '@zenstackhq/sdk/ast';

export type TransformerParams = {
enumTypes?: PrismaDMMF.SchemaEnum[];
Expand All @@ -9,6 +10,7 @@ export type TransformerParams = {
aggregateOperationSupport?: AggregateOperationSupport;
isDefaultPrismaClientOutput?: boolean;
prismaClientOutputPath?: string;
zmodel: Model;
};

export type AggregateOperationSupport = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {
emitProject,
getDataModels,
getLiteral,
getPrismaClientImportSpec,
GUARD_FIELD_NAME,
PluginError,
PluginOptions,
Expand Down Expand Up @@ -63,10 +64,11 @@ export default class PolicyGenerator {
});

// import enums
const prismaImport = getPrismaClientImportSpec(model, output);
for (const e of model.declarations.filter((d) => isEnum(d))) {
sf.addImportDeclaration({
namedImports: [{ name: e.name }],
moduleSpecifier: '@prisma/client',
moduleSpecifier: prismaImport,
});
}

Expand Down
Loading