diff --git a/packages/plugins/react/src/generator/react-query.ts b/packages/plugins/react/src/generator/react-query.ts index 0a9d08481..6cbeb6f13 100644 --- a/packages/plugins/react/src/generator/react-query.ts +++ b/packages/plugins/react/src/generator/react-query.ts @@ -1,5 +1,5 @@ import { DMMF } from '@prisma/generator-helper'; -import { PluginOptions, createProject, getDataModels, saveProject } from '@zenstackhq/sdk'; +import { PluginOptions, createProject, getDataModels, getPrismaClientImportSpec, saveProject } from '@zenstackhq/sdk'; import { DataModel, Model } from '@zenstackhq/sdk/ast'; import { requireOption, resolvePath } from '@zenstackhq/sdk/utils'; import { paramCase } from 'change-case'; @@ -147,14 +147,18 @@ function generateMutationHook( function generateModelHooks(project: Project, outDir: string, model: DataModel, mapping: DMMF.ModelMapping) { const fileName = paramCase(model.name); - const sf = project.createSourceFile(path.join(outDir, `${fileName}.ts`), undefined, { overwrite: true }); + + const hooksFile = path.resolve(outDir, `${fileName}.ts`); + const sf = project.createSourceFile(hooksFile, undefined, { overwrite: true }); sf.addStatements('/* eslint-disable */'); + const prismaImport = getPrismaClientImportSpec(model.$container, path.dirname(hooksFile)); + sf.addImportDeclaration({ namedImports: ['Prisma', model.name], isTypeOnly: true, - moduleSpecifier: '@prisma/client', + moduleSpecifier: prismaImport, }); sf.addStatements([ `import { useContext } from 'react';`, diff --git a/packages/plugins/react/src/generator/swr.ts b/packages/plugins/react/src/generator/swr.ts index 766b3e853..65ffa06d2 100644 --- a/packages/plugins/react/src/generator/swr.ts +++ b/packages/plugins/react/src/generator/swr.ts @@ -4,6 +4,7 @@ import { PluginOptions, createProject, getDataModels, + getPrismaClientImportSpec, requireOption, resolvePath, saveProject, @@ -56,10 +57,11 @@ function generateModelHooks(project: Project, outDir: string, model: DataModel, sf.addStatements('/* eslint-disable */'); + const prismaImport = getPrismaClientImportSpec(model.$container, outDir); sf.addImportDeclaration({ namedImports: ['Prisma', model.name], isTypeOnly: true, - moduleSpecifier: '@prisma/client', + moduleSpecifier: prismaImport, }); sf.addStatements([ `import { useContext } from 'react';`, diff --git a/packages/plugins/swr/src/generator.ts b/packages/plugins/swr/src/generator.ts index e7ba5e555..8cd3a3fac 100644 --- a/packages/plugins/swr/src/generator.ts +++ b/packages/plugins/swr/src/generator.ts @@ -4,6 +4,7 @@ import { PluginOptions, createProject, getDataModels, + getPrismaClientImportSpec, requireOption, resolvePath, saveProject, @@ -59,10 +60,11 @@ function generateModelHooks(project: Project, outDir: string, model: DataModel, sf.addStatements('/* eslint-disable */'); + const prismaImport = getPrismaClientImportSpec(model.$container, outDir); sf.addImportDeclaration({ namedImports: ['Prisma', model.name], isTypeOnly: true, - moduleSpecifier: '@prisma/client', + moduleSpecifier: prismaImport, }); sf.addStatements([ `import { useContext } from 'react';`, diff --git a/packages/plugins/tanstack-query/src/generator.ts b/packages/plugins/tanstack-query/src/generator.ts index 49301061d..5da93737d 100644 --- a/packages/plugins/tanstack-query/src/generator.ts +++ b/packages/plugins/tanstack-query/src/generator.ts @@ -4,6 +4,7 @@ import { PluginOptions, createProject, getDataModels, + getPrismaClientImportSpec, requireOption, resolvePath, saveProject, @@ -218,10 +219,11 @@ function generateModelHooks( sf.addStatements('/* eslint-disable */'); + const prismaImport = getPrismaClientImportSpec(model.$container, outDir); sf.addImportDeclaration({ namedImports: ['Prisma', model.name], isTypeOnly: true, - moduleSpecifier: '@prisma/client', + moduleSpecifier: prismaImport, }); sf.addStatements(makeBaseImports(target)); diff --git a/packages/plugins/trpc/src/generator.ts b/packages/plugins/trpc/src/generator.ts index c43db83bc..1ffdeaafb 100644 --- a/packages/plugins/trpc/src/generator.ts +++ b/packages/plugins/trpc/src/generator.ts @@ -4,6 +4,7 @@ import { PluginError, PluginOptions, RUNTIME_PACKAGE, + getPrismaClientImportSpec, requireOption, resolvePath, saveProject, @@ -53,7 +54,7 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. const hiddenModels: string[] = []; resolveModelsComments(models, hiddenModels); - createAppRouter(outDir, modelOperations, hiddenModels, generateModelActions, generateClientHelpers); + createAppRouter(outDir, modelOperations, hiddenModels, generateModelActions, generateClientHelpers, model); createHelper(outDir); await saveProject(project); @@ -64,14 +65,17 @@ function createAppRouter( modelOperations: DMMF.ModelMapping[], hiddenModels: string[], generateModelActions: string[] | undefined, - generateClientHelpers: string[] | undefined + generateClientHelpers: string[] | undefined, + zmodel: Model ) { - const appRouter = project.createSourceFile(path.resolve(outDir, 'routers', `index.ts`), undefined, { + const indexFile = path.resolve(outDir, 'routers', `index.ts`); + const appRouter = project.createSourceFile(indexFile, undefined, { overwrite: true, }); appRouter.addStatements('/* eslint-disable */'); + const prismaImport = getPrismaClientImportSpec(zmodel, path.dirname(indexFile)); appRouter.addImportDeclarations([ { namedImports: ['AnyRootConfig'], @@ -79,7 +83,7 @@ function createAppRouter( }, { namedImports: ['PrismaClient'], - moduleSpecifier: '@prisma/client', + moduleSpecifier: prismaImport, }, { namedImports: ['createRouterFactory', 'AnyRouter'], @@ -133,7 +137,8 @@ function createAppRouter( operations, outDir, generateModelActions, - generateClientHelpers + generateClientHelpers, + zmodel ); appRouter.addImportDeclaration({ @@ -201,7 +206,8 @@ function generateModelCreateRouter( operations: Record, outputDir: string, generateModelActions: string[] | undefined, - generateClientHelpers: string[] | undefined + generateClientHelpers: string[] | undefined, + zmodel: Model ) { const modelRouter = project.createSourceFile(path.resolve(outputDir, 'routers', `${model}.router.ts`), undefined, { overwrite: true, @@ -219,7 +225,7 @@ function generateModelCreateRouter( generateRouterSchemaImports(modelRouter, model); generateHelperImport(modelRouter); if (generateClientHelpers) { - generateRouterTypingImports(modelRouter); + generateRouterTypingImports(modelRouter, zmodel); } const createRouterFunc = modelRouter.addFunction({ diff --git a/packages/plugins/trpc/src/helpers.ts b/packages/plugins/trpc/src/helpers.ts index 41fc3f2d1..0df70aae5 100644 --- a/packages/plugins/trpc/src/helpers.ts +++ b/packages/plugins/trpc/src/helpers.ts @@ -1,9 +1,10 @@ import { DMMF } from '@prisma/generator-helper'; -import { PluginError } from '@zenstackhq/sdk'; +import { PluginError, getPrismaClientImportSpec } from '@zenstackhq/sdk'; import { CodeBlockWriter, SourceFile } from 'ts-morph'; import { upperCaseFirst } from 'upper-case-first'; import { name } from '.'; import { uncapitalizeFirstLetter } from './utils/uncapitalizeFirstLetter'; +import { Model } from '@zenstackhq/sdk/ast'; export function generateProcedure( writer: CodeBlockWriter, @@ -226,9 +227,11 @@ export function generateRouterTyping(writer: CodeBlockWriter, opType: string, mo }); } -export function generateRouterTypingImports(sourceFile: SourceFile) { +export function generateRouterTypingImports(sourceFile: SourceFile, model: Model) { + const importingDir = sourceFile.getDirectoryPath(); + const prismaImport = getPrismaClientImportSpec(model, importingDir); sourceFile.addStatements([ - `import type { Prisma } from '@prisma/client';`, + `import type { Prisma } from '${prismaImport}';`, `import type { UseTRPCMutationOptions, UseTRPCMutationResult, UseTRPCQueryOptions, UseTRPCQueryResult, UseTRPCInfiniteQueryOptions, UseTRPCInfiniteQueryResult } from '@trpc/react-query/shared';`, `import type { TRPCClientErrorLike } from '@trpc/client';`, `import type { AnyRouter } from '@trpc/server';`, diff --git a/packages/plugins/trpc/src/zod/generator.ts b/packages/plugins/trpc/src/zod/generator.ts index daa5064a9..72b0cf335 100644 --- a/packages/plugins/trpc/src/zod/generator.ts +++ b/packages/plugins/trpc/src/zod/generator.ts @@ -26,7 +26,11 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. const outputObjectTypes = prismaClientDmmf.schema.outputObjectTypes.prisma; const models: DMMF.Model[] = prismaClientDmmf.datamodel.models; - await generateEnumSchemas(prismaClientDmmf.schema.enumTypes.prisma, prismaClientDmmf.schema.enumTypes.model ?? []); + await generateEnumSchemas( + prismaClientDmmf.schema.enumTypes.prisma, + prismaClientDmmf.schema.enumTypes.model ?? [], + model + ); const dataSource = model.declarations.find((d): d is DataSource => isDataSource(d)); @@ -43,8 +47,8 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. const aggregateOperationSupport = resolveAggregateOperationSupport(inputObjectTypes); - await generateObjectSchemas(inputObjectTypes, output); - await generateModelSchemas(models, modelOperations, aggregateOperationSupport); + await generateObjectSchemas(inputObjectTypes, output, model); + await generateModelSchemas(models, modelOperations, aggregateOperationSupport, model); } async function handleGeneratorOutputValue(output: string) { @@ -56,22 +60,27 @@ async function handleGeneratorOutputValue(output: string) { Transformer.setOutputPath(output); } -async function generateEnumSchemas(prismaSchemaEnum: DMMF.SchemaEnum[], modelSchemaEnum: DMMF.SchemaEnum[]) { +async function generateEnumSchemas( + prismaSchemaEnum: DMMF.SchemaEnum[], + modelSchemaEnum: DMMF.SchemaEnum[], + zmodel: Model +) { const enumTypes = [...prismaSchemaEnum, ...modelSchemaEnum]; const enumNames = enumTypes.map((enumItem) => enumItem.name); Transformer.enumNames = enumNames ?? []; const transformer = new Transformer({ enumTypes, + zmodel, }); await transformer.generateEnumSchemas(); } -async function generateObjectSchemas(inputObjectTypes: DMMF.InputType[], output: string) { +async function generateObjectSchemas(inputObjectTypes: DMMF.InputType[], output: string, zmodel: Model) { const moduleNames: string[] = []; for (let i = 0; i < inputObjectTypes.length; i += 1) { const fields = inputObjectTypes[i]?.fields; const name = inputObjectTypes[i]?.name; - const transformer = new Transformer({ name, fields }); + const transformer = new Transformer({ name, fields, zmodel }); const moduleName = await transformer.generateObjectSchema(); moduleNames.push(moduleName); } @@ -84,12 +93,14 @@ async function generateObjectSchemas(inputObjectTypes: DMMF.InputType[], output: async function generateModelSchemas( models: DMMF.Model[], modelOperations: DMMF.ModelMapping[], - aggregateOperationSupport: AggregateOperationSupport + aggregateOperationSupport: AggregateOperationSupport, + zmodel: Model ) { const transformer = new Transformer({ models, modelOperations, aggregateOperationSupport, + zmodel, }); await transformer.generateModelSchemas(); } diff --git a/packages/plugins/trpc/src/zod/transformer.ts b/packages/plugins/trpc/src/zod/transformer.ts index 7b43befac..585cab5de 100644 --- a/packages/plugins/trpc/src/zod/transformer.ts +++ b/packages/plugins/trpc/src/zod/transformer.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/ban-ts-comment */ import type { DMMF as PrismaDMMF } from '@prisma/generator-helper'; -import { AUXILIARY_FIELDS } from '@zenstackhq/sdk'; +import { AUXILIARY_FIELDS, getPrismaClientImportSpec } from '@zenstackhq/sdk'; +import { Model } from '@zenstackhq/sdk/ast'; import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers'; import indentString from '@zenstackhq/sdk/utils'; import path from 'path'; @@ -21,8 +22,7 @@ export default class Transformer { static provider: string; private static outputPath = './generated'; private hasJson = false; - private static prismaClientOutputPath = '@prisma/client'; - private static isCustomPrismaClientOutputPath = false; + private zmodel: Model; constructor(params: TransformerParams) { this.name = params.name ?? ''; @@ -31,6 +31,7 @@ export default class Transformer { this.modelOperations = params.modelOperations ?? []; this.aggregateOperationSupport = params.aggregateOperationSupport ?? {}; this.enumTypes = params.enumTypes ?? []; + this.zmodel = params.zmodel; } static setOutputPath(outPath: string) { @@ -41,11 +42,6 @@ export default class Transformer { return this.outputPath; } - static setPrismaClientOutputPath(prismaClientCustomPath: string) { - this.prismaClientOutputPath = prismaClientCustomPath; - this.isCustomPrismaClientOutputPath = prismaClientCustomPath !== '@prisma/client'; - } - async generateEnumSchemas() { for (const enumType of this.enumTypes) { const { name, values } = enumType; @@ -268,27 +264,8 @@ export default class Transformer { } generateImportPrismaStatement() { - let prismaClientImportPath: string; - if (Transformer.isCustomPrismaClientOutputPath) { - /** - * If a custom location was designated for the prisma client, we need to figure out the - * relative path from {outputPath}/schemas/objects to {prismaClientCustomPath} - */ - const fromPath = path.join(Transformer.outputPath, 'schemas', 'objects'); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const toPath = Transformer.prismaClientOutputPath!; - const relativePathFromOutputToPrismaClient = path - .relative(fromPath, toPath) - .split(path.sep) - .join(path.posix.sep); - prismaClientImportPath = relativePathFromOutputToPrismaClient; - } else { - /** - * If the default output path for prisma client (@prisma/client) is being used, we can import from it directly - * without having to resolve a relative path - */ - prismaClientImportPath = Transformer.prismaClientOutputPath; - } + const importingFrom = path.resolve(Transformer.outputPath, 'schemas', 'objects'); + const prismaClientImportPath = getPrismaClientImportSpec(this.zmodel, importingFrom); return `import type { Prisma } from '${prismaClientImportPath}';\n\n`; } diff --git a/packages/plugins/trpc/src/zod/types.ts b/packages/plugins/trpc/src/zod/types.ts index a02b9ca7c..2e5a8e624 100644 --- a/packages/plugins/trpc/src/zod/types.ts +++ b/packages/plugins/trpc/src/zod/types.ts @@ -1,4 +1,5 @@ import { DMMF as PrismaDMMF } from '@prisma/generator-helper'; +import { Model } from '@zenstackhq/sdk/ast'; export type TransformerParams = { enumTypes?: PrismaDMMF.SchemaEnum[]; @@ -9,6 +10,7 @@ export type TransformerParams = { aggregateOperationSupport?: AggregateOperationSupport; isDefaultPrismaClientOutput?: boolean; prismaClientOutputPath?: string; + zmodel: Model; }; export type AggregateOperationSupport = { diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index a25d17458..00a325a3b 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -21,6 +21,7 @@ import { emitProject, getDataModels, getLiteral, + getPrismaClientImportSpec, GUARD_FIELD_NAME, PluginError, PluginOptions, @@ -63,10 +64,11 @@ export default class PolicyGenerator { }); // import enums + const prismaImport = getPrismaClientImportSpec(model, output); for (const e of model.declarations.filter((d) => isEnum(d))) { sf.addImportDeclaration({ namedImports: [{ name: e.name }], - moduleSpecifier: '@prisma/client', + moduleSpecifier: prismaImport, }); } diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index d723533eb..5757317de 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -39,7 +39,8 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. await generateEnumSchemas( prismaClientDmmf.schema.enumTypes.prisma, prismaClientDmmf.schema.enumTypes.model ?? [], - project + project, + model ); const dataSource = model.declarations.find((d): d is DataSource => isDataSource(d)); @@ -57,8 +58,8 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. const aggregateOperationSupport = resolveAggregateOperationSupport(inputObjectTypes); - await generateObjectSchemas(inputObjectTypes, project, output); - await generateModelSchemas(models, modelOperations, aggregateOperationSupport, project); + await generateObjectSchemas(inputObjectTypes, project, output, model); + await generateModelSchemas(models, modelOperations, aggregateOperationSupport, project, model); const shouldCompile = options.compile !== false; if (!shouldCompile || options.preserveTsFiles === true) { @@ -82,7 +83,8 @@ async function handleGeneratorOutputValue(output: string) { async function generateEnumSchemas( prismaSchemaEnum: DMMF.SchemaEnum[], modelSchemaEnum: DMMF.SchemaEnum[], - project: Project + project: Project, + zmodel: Model ) { const enumTypes = [...prismaSchemaEnum, ...modelSchemaEnum]; const enumNames = enumTypes.map((enumItem) => enumItem.name); @@ -90,16 +92,22 @@ async function generateEnumSchemas( const transformer = new Transformer({ enumTypes, project, + zmodel, }); await transformer.generateEnumSchemas(); } -async function generateObjectSchemas(inputObjectTypes: DMMF.InputType[], project: Project, output: string) { +async function generateObjectSchemas( + inputObjectTypes: DMMF.InputType[], + project: Project, + output: string, + zmodel: Model +) { const moduleNames: string[] = []; for (let i = 0; i < inputObjectTypes.length; i += 1) { const fields = inputObjectTypes[i]?.fields; const name = inputObjectTypes[i]?.name; - const transformer = new Transformer({ name, fields, project }); + const transformer = new Transformer({ name, fields, project, zmodel }); const moduleName = transformer.generateObjectSchema(); moduleNames.push(moduleName); } @@ -114,13 +122,15 @@ async function generateModelSchemas( models: DMMF.Model[], modelOperations: DMMF.ModelMapping[], aggregateOperationSupport: AggregateOperationSupport, - project: Project + project: Project, + zmodel: Model ) { const transformer = new Transformer({ models, modelOperations, aggregateOperationSupport, project, + zmodel, }); await transformer.generateModelSchemas(); } diff --git a/packages/schema/src/plugins/zod/transformer.ts b/packages/schema/src/plugins/zod/transformer.ts index 325aba5b4..723e8b295 100644 --- a/packages/schema/src/plugins/zod/transformer.ts +++ b/packages/schema/src/plugins/zod/transformer.ts @@ -1,11 +1,12 @@ /* eslint-disable @typescript-eslint/ban-ts-comment */ import type { DMMF as PrismaDMMF } from '@prisma/generator-helper'; -import { AUXILIARY_FIELDS } from '@zenstackhq/sdk'; +import { Model } from '@zenstackhq/language/ast'; +import { AUXILIARY_FIELDS, getPrismaClientImportSpec } from '@zenstackhq/sdk'; import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers'; import indentString from '@zenstackhq/sdk/utils'; -import { upperCaseFirst } from 'upper-case-first'; import path from 'path'; import { Project } from 'ts-morph'; +import { upperCaseFirst } from 'upper-case-first'; import { AggregateOperationSupport, TransformerParams } from './types'; export default class Transformer { @@ -22,9 +23,8 @@ export default class Transformer { static provider: string; private static outputPath = './generated'; private hasJson = false; - private static prismaClientOutputPath = '@prisma/client'; - private static isCustomPrismaClientOutputPath = false; private project: Project; + private zmodel: Model; constructor(params: TransformerParams) { this.name = params.name ?? ''; @@ -34,6 +34,7 @@ export default class Transformer { this.aggregateOperationSupport = params.aggregateOperationSupport ?? {}; this.enumTypes = params.enumTypes ?? []; this.project = params.project; + this.zmodel = params.zmodel; } static setOutputPath(outPath: string) { @@ -44,11 +45,6 @@ export default class Transformer { return this.outputPath; } - static setPrismaClientOutputPath(prismaClientCustomPath: string) { - this.prismaClientOutputPath = prismaClientCustomPath; - this.isCustomPrismaClientOutputPath = prismaClientCustomPath !== '@prisma/client'; - } - async generateEnumSchemas() { for (const enumType of this.enumTypes) { const { name, values } = enumType; @@ -270,27 +266,10 @@ export default class Transformer { } generateImportPrismaStatement() { - let prismaClientImportPath: string; - if (Transformer.isCustomPrismaClientOutputPath) { - /** - * If a custom location was designated for the prisma client, we need to figure out the - * relative path from {outputPath}/objects to {prismaClientCustomPath} - */ - const fromPath = path.join(Transformer.outputPath, 'objects'); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const toPath = Transformer.prismaClientOutputPath!; - const relativePathFromOutputToPrismaClient = path - .relative(fromPath, toPath) - .split(path.sep) - .join(path.posix.sep); - prismaClientImportPath = relativePathFromOutputToPrismaClient; - } else { - /** - * If the default output path for prisma client (@prisma/client) is being used, we can import from it directly - * without having to resolve a relative path - */ - prismaClientImportPath = Transformer.prismaClientOutputPath; - } + const prismaClientImportPath = getPrismaClientImportSpec( + this.zmodel, + path.resolve(Transformer.outputPath, './objects') + ); return `import type { Prisma } from '${prismaClientImportPath}';\n\n`; } diff --git a/packages/schema/src/plugins/zod/types.ts b/packages/schema/src/plugins/zod/types.ts index 49c35c023..33a377a29 100644 --- a/packages/schema/src/plugins/zod/types.ts +++ b/packages/schema/src/plugins/zod/types.ts @@ -1,4 +1,5 @@ import { DMMF as PrismaDMMF } from '@prisma/generator-helper'; +import { Model } from '@zenstackhq/language/ast'; import { Project } from 'ts-morph'; export type TransformerParams = { @@ -11,6 +12,7 @@ export type TransformerParams = { isDefaultPrismaClientOutput?: boolean; prismaClientOutputPath?: string; project: Project; + zmodel: Model; }; export type AggregateOperationSupport = { diff --git a/packages/schema/tests/cli/plugins.test.ts b/packages/schema/tests/cli/plugins.test.ts index 4d1520f2e..c99637f73 100644 --- a/packages/schema/tests/cli/plugins.test.ts +++ b/packages/schema/tests/cli/plugins.test.ts @@ -1,7 +1,7 @@ /* eslint-disable @typescript-eslint/no-var-requires */ /// -import { getWorkspaceNpmCacheFolder } from '@zenstackhq/testtools'; +import { getWorkspaceNpmCacheFolder, run } from '@zenstackhq/testtools'; import * as fs from 'fs'; import * as path from 'path'; import * as tmp from 'tmp'; @@ -25,6 +25,22 @@ describe('CLI Plugins Tests', () => { fs.writeFileSync('.npmrc', `cache=${getWorkspaceNpmCacheFolder(__dirname)}`); } + async function initProject() { + fs.writeFileSync('package.json', JSON.stringify({ name: 'my app', version: '1.0.0' })); + createNpmrc(); + const program = createProgram(); + + // typescript + run('npm install -D typescript'); + run('npx tsc --init'); + + // deps + run('npm install react swr @tanstack/react-query @trpc/server @types/react'); + + await program.parseAsync(['init', '--tag', 'latest'], { from: 'user' }); + return program; + } + const plugins = [ `plugin prisma { provider = '@core/prisma' @@ -44,9 +60,14 @@ describe('CLI Plugins Tests', () => { provider = '@core/zod' output = 'zod' }`, - `plugin react { - provider = '${path.join(__dirname, '../../../plugins/react/dist')}' - output = 'lib/default-hooks' + `plugin tanstack { + provider = '${path.join(__dirname, '../../../plugins/tanstack-query/dist')}' + output = 'lib/tanstack-query' + target = 'react' + }`, + `plugin swr { + provider = '${path.join(__dirname, '../../../plugins/swr/dist')}' + output = 'lib/swr' }`, `plugin trpc { provider = '${path.join(__dirname, '../../../plugins/trpc/dist')}' @@ -68,17 +89,103 @@ describe('CLI Plugins Tests', () => { }`, ]; - it('all plugins', async () => { - fs.writeFileSync('package.json', JSON.stringify({ name: 'my app', version: '1.0.0' })); - createNpmrc(); - const program = createProgram(); - await program.parseAsync(['init', '--tag', 'latest'], { from: 'user' }); + const BASE_MODEL = ` + datasource db { + provider = 'postgresql' + url = env('DATABASE_URL') + } + + enum Role { + USER + ADMIN + } + + model User { + id String @id @default(cuid()) + email String @unique @email + role Role @default(USER) + posts Post[] + @@allow('create', true) + @@allow('all', auth() == this || role == ADMIN) + } + + model Post { + id String @id @default(cuid()) + createdAt DateTime @default(now()) + published Boolean @default(false) + author User? @relation(fields: [authorId], references: [id]) + authorId String? + + @@allow('read', auth() != null && published) + @@allow('all', auth() == author) + } + `; - let schemaContent = fs.readFileSync('schema.zmodel', 'utf-8'); + it('all plugins standard prisma client output path', async () => { + const program = await initProject(); + + let schemaContent = ` +generator client { + provider = "prisma-client-js" +} + +${BASE_MODEL} + `; for (const plugin of plugins) { schemaContent += `\n${plugin}`; } fs.writeFileSync('schema.zmodel', schemaContent); + await program.parseAsync(['generate', '--no-dependency-check'], { from: 'user' }); + + // compile + run('npx tsc'); + }); + + it('all plugins custom prisma client output path', async () => { + const program = await initProject(); + + let schemaContent = ` +generator client { + provider = "prisma-client-js" + output = "foo/bar" +} + +${BASE_MODEL} +`; + for (const plugin of plugins) { + schemaContent += `\n${plugin}`; + } + fs.writeFileSync('schema.zmodel', schemaContent); + + await program.parseAsync(['generate', '--no-dependency-check'], { from: 'user' }); + + // compile + run('npx tsc'); + }); + + it('all plugins absolute prisma client output path', async () => { + const { name: output } = tmp.dirSync({ unsafeCleanup: true }); + console.log('Output prisma client to:', output); + + const program = await initProject(); + + let schemaContent = ` +generator client { + provider = "prisma-client-js" + output = "${output}" +} + +${BASE_MODEL} +`; + for (const plugin of plugins) { + schemaContent += `\n${plugin}`; + } + fs.writeFileSync('schema.zmodel', schemaContent); + + await program.parseAsync(['generate', '--no-dependency-check'], { from: 'user' }); + + // compile + run('npx tsc'); }); }); diff --git a/packages/schema/tests/plugins/policy.test.ts b/packages/schema/tests/plugins/policy.test.ts index 6666abb5e..9b2ab818b 100644 --- a/packages/schema/tests/plugins/policy.test.ts +++ b/packages/schema/tests/plugins/policy.test.ts @@ -1,3 +1,5 @@ +/// + import { loadSchema } from '@zenstackhq/testtools'; describe('Policy plugin tests', () => { diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index e82960d44..53da49054 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -3,3 +3,4 @@ export * from './constants'; export * from './types'; export * from './utils'; export * from './policy'; +export * from './prisma'; diff --git a/packages/sdk/src/prisma.ts b/packages/sdk/src/prisma.ts new file mode 100644 index 000000000..8558facdd --- /dev/null +++ b/packages/sdk/src/prisma.ts @@ -0,0 +1,58 @@ +import { GeneratorDecl, Model, Plugin, isGeneratorDecl, isPlugin } from './ast'; +import { getLiteral } from './utils'; +import path from 'path'; + +/** + * Given a ZModel and an import context directory, compute the import spec for the Prisma Client. + */ +export function getPrismaClientImportSpec(model: Model, importingFromDir: string) { + const generator = model.declarations.find( + (d) => + isGeneratorDecl(d) && + d.fields.some((f) => f.name === 'provider' && getLiteral(f.value) === 'prisma-client-js') + ) as GeneratorDecl; + + const clientOutputField = generator?.fields.find((f) => f.name === 'output'); + const clientOutput = getLiteral(clientOutputField?.value); + + if (!clientOutput) { + // no user-declared Prisma Client output location + return '@prisma/client'; + } + + if (path.isAbsolute(clientOutput)) { + // absolute path + return clientOutput; + } + + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + const zmodelDir = path.dirname(model.$document!.uri.fsPath); + + // compute prisma schema absolute output path + let prismaSchemaOutputDir = path.resolve(zmodelDir, './prisma'); + const prismaPlugin = model.declarations.find( + (d) => isPlugin(d) && d.fields.some((f) => f.name === 'provider' && getLiteral(f.value) === '@core/prisma') + ) as Plugin; + if (prismaPlugin) { + const output = getLiteral(prismaPlugin.fields.find((f) => f.name === 'output')?.value); + if (output) { + if (path.isAbsolute(output)) { + // absolute prisma schema output path + prismaSchemaOutputDir = path.dirname(output); + } else { + prismaSchemaOutputDir = path.dirname(path.resolve(zmodelDir, output)); + } + } + } + + // resolve the prisma client output path, which is relative to the prisma schema + const resolvedPrismaClientOutput = path.resolve(prismaSchemaOutputDir, clientOutput); + + // DEBUG: + // console.log('PRISMA SCHEMA PATH:', prismaSchemaOutputDir); + // console.log('PRISMA CLIENT PATH:', resolvedPrismaClientOutput); + // console.log('IMPORTING PATH:', importingFromDir); + + // compute prisma client absolute output dir relative to the importing file + return path.relative(importingFromDir, resolvedPrismaClientOutput); +}