Skip to content

Commit 9daff09

Browse files
committed
fix(policy): properly handle array-form of upsert payload
Fixes #1080
1 parent 4dd7aa0 commit 9daff09

File tree

4 files changed

+204
-39
lines changed

4 files changed

+204
-39
lines changed

packages/runtime/src/cross/nested-write-visitor.ts

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import type { FieldInfo, ModelMeta } from './model-meta';
55
import { resolveField } from './model-meta';
66
import { MaybePromise, PrismaWriteActionType, PrismaWriteActions } from './types';
7-
import { enumerate, getModelFields } from './utils';
7+
import { getModelFields } from './utils';
88

99
type NestingPathItem = { field?: FieldInfo; model: string; where: any; unique: boolean };
1010

@@ -155,7 +155,7 @@ export class NestedWriteVisitor {
155155
// visit payload
156156
switch (action) {
157157
case 'create':
158-
for (const item of enumerate(data)) {
158+
for (const item of this.enumerateReverse(data)) {
159159
const newContext = pushNewContext(field, model, {});
160160
let callbackResult: any;
161161
if (this.callback.create) {
@@ -183,7 +183,7 @@ export class NestedWriteVisitor {
183183
break;
184184

185185
case 'connectOrCreate':
186-
for (const item of enumerate(data)) {
186+
for (const item of this.enumerateReverse(data)) {
187187
const newContext = pushNewContext(field, model, item.where);
188188
let callbackResult: any;
189189
if (this.callback.connectOrCreate) {
@@ -198,7 +198,7 @@ export class NestedWriteVisitor {
198198

199199
case 'connect':
200200
if (this.callback.connect) {
201-
for (const item of enumerate(data)) {
201+
for (const item of this.enumerateReverse(data)) {
202202
const newContext = pushNewContext(field, model, item, true);
203203
await this.callback.connect(model, item, newContext);
204204
}
@@ -210,7 +210,7 @@ export class NestedWriteVisitor {
210210
// if relation is to-many, the payload is a unique filter object
211211
// if relation is to-one, the payload can only be boolean `true`
212212
if (this.callback.disconnect) {
213-
for (const item of enumerate(data)) {
213+
for (const item of this.enumerateReverse(data)) {
214214
const newContext = pushNewContext(field, model, item, typeof item === 'object');
215215
await this.callback.disconnect(model, item, newContext);
216216
}
@@ -225,7 +225,7 @@ export class NestedWriteVisitor {
225225
break;
226226

227227
case 'update':
228-
for (const item of enumerate(data)) {
228+
for (const item of this.enumerateReverse(data)) {
229229
const newContext = pushNewContext(field, model, item.where);
230230
let callbackResult: any;
231231
if (this.callback.update) {
@@ -244,7 +244,7 @@ export class NestedWriteVisitor {
244244
break;
245245

246246
case 'updateMany':
247-
for (const item of enumerate(data)) {
247+
for (const item of this.enumerateReverse(data)) {
248248
const newContext = pushNewContext(field, model, item.where);
249249
let callbackResult: any;
250250
if (this.callback.updateMany) {
@@ -258,7 +258,7 @@ export class NestedWriteVisitor {
258258
break;
259259

260260
case 'upsert': {
261-
for (const item of enumerate(data)) {
261+
for (const item of this.enumerateReverse(data)) {
262262
const newContext = pushNewContext(field, model, item.where);
263263
let callbackResult: any;
264264
if (this.callback.upsert) {
@@ -278,7 +278,7 @@ export class NestedWriteVisitor {
278278

279279
case 'delete': {
280280
if (this.callback.delete) {
281-
for (const item of enumerate(data)) {
281+
for (const item of this.enumerateReverse(data)) {
282282
const newContext = pushNewContext(field, model, toplevel ? item.where : item);
283283
await this.callback.delete(model, item, newContext);
284284
}
@@ -288,7 +288,7 @@ export class NestedWriteVisitor {
288288

289289
case 'deleteMany':
290290
if (this.callback.deleteMany) {
291-
for (const item of enumerate(data)) {
291+
for (const item of this.enumerateReverse(data)) {
292292
const newContext = pushNewContext(field, model, toplevel ? item.where : item);
293293
await this.callback.deleteMany(model, item, newContext);
294294
}
@@ -336,4 +336,16 @@ export class NestedWriteVisitor {
336336
}
337337
}
338338
}
339+
340+
// enumerate a (possible) array in reverse order, so that the enumeration
341+
// callback can safely delete the current item
342+
private *enumerateReverse(data: any) {
343+
if (Array.isArray(data)) {
344+
for (let i = data.length - 1; i >= 0; i--) {
345+
yield data[i];
346+
}
347+
} else {
348+
yield data;
349+
}
350+
}
339351
}

packages/runtime/src/enhancements/policy/handler.ts

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -343,29 +343,19 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
343343
}
344344
}
345345

346-
if (context.parent.connect) {
347-
// if the payload parent already has a "connect" clause, merge it
348-
if (Array.isArray(context.parent.connect)) {
349-
context.parent.connect.push(args.where);
350-
} else {
351-
context.parent.connect = [context.parent.connect, args.where];
352-
}
353-
} else {
354-
// otherwise, create a new "connect" clause
355-
context.parent.connect = args.where;
356-
}
346+
this.mergeToParent(context.parent, 'connect', args.where);
357347
// record the key of connected entities so we can avoid validating them later
358348
connectedEntities.add(getEntityKey(model, existing));
359349
} else {
360350
// create case
361351
pushIdFields(model, context);
362352

363353
// create a new "create" clause at the parent level
364-
context.parent.create = args.create;
354+
this.mergeToParent(context.parent, 'create', args.create);
365355
}
366356

367357
// remove the connectOrCreate clause
368-
delete context.parent['connectOrCreate'];
358+
this.removeFromParent(context.parent, 'connectOrCreate', args);
369359

370360
// return false to prevent visiting the nested payload
371361
return false;
@@ -895,7 +885,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
895885
await _create(model, args, context);
896886

897887
// remove it from the update payload
898-
delete context.parent.create;
888+
this.removeFromParent(context.parent, 'create', args);
899889

900890
// don't visit payload
901891
return false;
@@ -928,22 +918,23 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
928918
await _registerPostUpdateCheck(model, uniqueFilter);
929919

930920
// convert upsert to update
931-
context.parent.update = {
921+
const convertedUpdate = {
932922
where: args.where,
933923
data: this.validateUpdateInputSchema(model, args.update),
934924
};
935-
delete context.parent.upsert;
925+
this.mergeToParent(context.parent, 'update', convertedUpdate);
926+
this.removeFromParent(context.parent, 'upsert', args);
936927

937928
// continue visiting the new payload
938-
return context.parent.update;
929+
return convertedUpdate;
939930
} else {
940931
// create case
941932

942933
// process the entire create subtree separately
943934
await _create(model, args.create, context);
944935

945936
// remove it from the update payload
946-
delete context.parent.upsert;
937+
this.removeFromParent(context.parent, 'upsert', args);
947938

948939
// don't visit payload
949940
return false;
@@ -1390,5 +1381,31 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
13901381
return requireField(this.modelMeta, fieldInfo.type, fieldInfo.backLink);
13911382
}
13921383

1384+
private mergeToParent(parent: any, key: string, value: any) {
1385+
if (parent[key]) {
1386+
if (Array.isArray(parent[key])) {
1387+
parent[key].push(value);
1388+
} else {
1389+
parent[key] = [parent[key], value];
1390+
}
1391+
} else {
1392+
parent[key] = value;
1393+
}
1394+
}
1395+
1396+
private removeFromParent(parent: any, key: string, data: any) {
1397+
if (parent[key] === data) {
1398+
delete parent[key];
1399+
} else if (Array.isArray(parent[key])) {
1400+
const idx = parent[key].indexOf(data);
1401+
if (idx >= 0) {
1402+
parent[key].splice(idx, 1);
1403+
if (parent[key].length === 0) {
1404+
delete parent[key];
1405+
}
1406+
}
1407+
}
1408+
}
1409+
13931410
//#endregion
13941411
}

tests/integration/tests/regression/issue-1078.test.ts

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { loadSchema } from '@zenstackhq/testtools';
22

33
describe('issue 1078', () => {
44
it('regression', async () => {
5-
const { prisma, enhance } = await loadSchema(
5+
const { enhance } = await loadSchema(
66
`
77
model Counter {
88
id String @id
@@ -12,21 +12,25 @@ describe('issue 1078', () => {
1212
1313
@@validate(value >= 0)
1414
@@allow('all', true)
15-
}
15+
}
1616
`
1717
);
1818

1919
const db = enhance();
2020

21-
const counter = await db.counter.create({
22-
data: { id: '1', name: 'It should create', value: 1 },
23-
});
21+
await expect(
22+
db.counter.create({
23+
data: { id: '1', name: 'It should create', value: 1 },
24+
})
25+
).toResolveTruthy();
2426

2527
//! This query fails validation
26-
const updated = await db.counter.update({
27-
where: { id: '1' },
28-
data: { name: 'It should update' },
29-
});
28+
await expect(
29+
db.counter.update({
30+
where: { id: '1' },
31+
data: { name: 'It should update' },
32+
})
33+
).toResolveTruthy();
3034
});
3135

3236
it('read', async () => {
@@ -37,8 +41,7 @@ describe('issue 1078', () => {
3741
title String @allow('read', true, true)
3842
content String
3943
}
40-
`,
41-
{ logPrismaQuery: true }
44+
`
4245
);
4346

4447
const db = enhance();
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import { loadSchema } from '@zenstackhq/testtools';
2+
3+
describe('issue 1080', () => {
4+
it('regression', async () => {
5+
const { enhance } = await loadSchema(
6+
`
7+
model Project {
8+
id String @id @unique @default(uuid())
9+
Fields Field[]
10+
11+
@@allow('all', true)
12+
}
13+
14+
model Field {
15+
id String @id @unique @default(uuid())
16+
name String
17+
Project Project @relation(fields: [projectId], references: [id])
18+
projectId String
19+
20+
@@allow('all', true)
21+
}
22+
`,
23+
{ logPrismaQuery: true }
24+
);
25+
26+
const db = enhance();
27+
28+
const project = await db.project.create({
29+
include: { Fields: true },
30+
data: {
31+
Fields: {
32+
create: [{ name: 'first' }, { name: 'second' }],
33+
},
34+
},
35+
});
36+
37+
let updated = await db.project.update({
38+
where: { id: project.id },
39+
include: { Fields: true },
40+
data: {
41+
Fields: {
42+
upsert: [
43+
{
44+
where: { id: project.Fields[0].id },
45+
create: { name: 'first1' },
46+
update: { name: 'first1' },
47+
},
48+
{
49+
where: { id: project.Fields[1].id },
50+
create: { name: 'second1' },
51+
update: { name: 'second1' },
52+
},
53+
],
54+
},
55+
},
56+
});
57+
expect(updated).toMatchObject({
58+
Fields: expect.arrayContaining([
59+
expect.objectContaining({ name: 'first1' }),
60+
expect.objectContaining({ name: 'second1' }),
61+
]),
62+
});
63+
64+
updated = await db.project.update({
65+
where: { id: project.id },
66+
include: { Fields: true },
67+
data: {
68+
Fields: {
69+
upsert: {
70+
where: { id: project.Fields[0].id },
71+
create: { name: 'first2' },
72+
update: { name: 'first2' },
73+
},
74+
},
75+
},
76+
});
77+
expect(updated).toMatchObject({
78+
Fields: expect.arrayContaining([
79+
expect.objectContaining({ name: 'first2' }),
80+
expect.objectContaining({ name: 'second1' }),
81+
]),
82+
});
83+
84+
updated = await db.project.update({
85+
where: { id: project.id },
86+
include: { Fields: true },
87+
data: {
88+
Fields: {
89+
upsert: {
90+
where: { id: project.Fields[0].id },
91+
create: { name: 'first3' },
92+
update: { name: 'first3' },
93+
},
94+
update: {
95+
where: { id: project.Fields[1].id },
96+
data: { name: 'second3' },
97+
},
98+
},
99+
},
100+
});
101+
expect(updated).toMatchObject({
102+
Fields: expect.arrayContaining([
103+
expect.objectContaining({ name: 'first3' }),
104+
expect.objectContaining({ name: 'second3' }),
105+
]),
106+
});
107+
108+
updated = await db.project.update({
109+
where: { id: project.id },
110+
include: { Fields: true },
111+
data: {
112+
Fields: {
113+
upsert: {
114+
where: { id: 'non-exist' },
115+
create: { name: 'third1' },
116+
update: { name: 'third1' },
117+
},
118+
update: {
119+
where: { id: project.Fields[1].id },
120+
data: { name: 'second4' },
121+
},
122+
},
123+
},
124+
});
125+
expect(updated).toMatchObject({
126+
Fields: expect.arrayContaining([
127+
expect.objectContaining({ name: 'first3' }),
128+
expect.objectContaining({ name: 'second4' }),
129+
expect.objectContaining({ name: 'third1' }),
130+
]),
131+
});
132+
});
133+
});

0 commit comments

Comments
 (0)