@@ -109,6 +109,51 @@ def _symbol_extractor(x):
109
109
"information." )
110
110
111
111
112
+ def _serialize_controls (gate ):
113
+ """Helper to serialize control qubits if applicable."""
114
+ if hasattr (gate , '_tfq_control_qubits' ):
115
+ return ',' .join (
116
+ v2 .qubit_to_proto_id (q ) for q in gate ._tfq_control_qubits )
117
+ return ''
118
+
119
+
120
+ def _serialize_control_vals (gate ):
121
+ """Helper to serialize control values if applicable.."""
122
+ if hasattr (gate , '_tfq_control_values' ):
123
+ return ',' .join (str (v [0 ]) for v in gate ._tfq_control_values )
124
+ return ''
125
+
126
+
127
+ class DelayedAssignmentGate (cirq .Gate ):
128
+ """Class to do control qubit assignment before sub_gate qubit assignment."""
129
+
130
+ def __init__ (self , gate_callable , control_qubits , control_values ):
131
+ self ._gate_callable = gate_callable
132
+ self ._control_qubits = control_qubits
133
+ self ._control_values = control_values
134
+
135
+ def _qid_shape_ (self ):
136
+ raise ValueError ("Called qid_shape on workaround class." )
137
+
138
+ # pylint: disable=invalid-name
139
+ def on (self , * qubits ):
140
+ """Returns gate_callable on qubits controlled by contol_qubits."""
141
+ return self ._gate_callable (* qubits ).controlled_by (
142
+ * self ._control_qubits , control_values = self ._control_values )
143
+
144
+ # pylint: enable=invalid-name
145
+
146
+
147
+ def _optional_control_promote (gate , qubits_message , values_message ):
148
+ """Optionally promote to controlled gate based on serialized control msg."""
149
+ if qubits_message == '' and values_message == '' :
150
+ return gate
151
+ qbs = [v2 .qubit_from_proto_id (qb ) for qb in qubits_message .split (',' )]
152
+ vals = [int (cv ) for cv in values_message .split (',' )]
153
+
154
+ return DelayedAssignmentGate (gate , qbs , vals )
155
+
156
+
112
157
def _eigen_gate_serializer (gate_type , serialized_id ):
113
158
"""Make standard serializer for eigen gates."""
114
159
@@ -124,7 +169,14 @@ def _eigen_gate_serializer(gate_type, serialized_id):
124
169
cirq .google .SerializingArg (
125
170
serialized_name = "global_shift" ,
126
171
serialized_type = float ,
127
- op_getter = lambda x : float (x .gate ._global_shift ))
172
+ op_getter = lambda x : float (x .gate ._global_shift )),
173
+ cirq .google .SerializingArg (serialized_name = "control_qubits" ,
174
+ serialized_type = str ,
175
+ op_getter = lambda x : _serialize_controls (x )),
176
+ cirq .google .SerializingArg (
177
+ serialized_name = "control_values" ,
178
+ serialized_type = str ,
179
+ op_getter = lambda x : _serialize_control_vals (x ))
128
180
]
129
181
return cirq .google .GateOpSerializer (gate_type = gate_type ,
130
182
serialized_gate_id = serialized_id ,
@@ -135,26 +187,35 @@ def _eigen_gate_serializer(gate_type, serialized_id):
135
187
def _eigen_gate_deserializer (gate_type , serialized_id ):
136
188
"""Make standard deserializer for eigen gates."""
137
189
138
- def _scalar_combiner (exponent , global_shift , exponent_scalar ):
190
+ def _scalar_combiner (exponent , global_shift , exponent_scalar ,
191
+ control_qubits , control_values ):
139
192
"""This is a workaround to support symbol scalar multiplication.
140
193
In the future we should likely get rid of this in favor of proper
141
194
expression parsing once cirq supports it. See cirq.op_serializer
142
195
and cirq's program protobuf for details. This is needed for things
143
196
like cirq.rx('alpha').
144
197
"""
145
198
if exponent_scalar == 1.0 :
146
- return gate_type (exponent = _round (exponent ),
147
- global_shift = _round (global_shift ))
148
- return gate_type (exponent = _round (exponent ) * _round (exponent_scalar ),
149
- global_shift = _round (global_shift ))
199
+ return _optional_control_promote (
200
+ gate_type (exponent = _round (exponent ),
201
+ global_shift = _round (global_shift )), control_qubits ,
202
+ control_values )
203
+ return _optional_control_promote (
204
+ gate_type (exponent = _round (exponent ) * _round (exponent_scalar ),
205
+ global_shift = _round (global_shift )), control_qubits ,
206
+ control_values )
150
207
151
208
args = [
152
209
cirq .google .DeserializingArg (serialized_name = "exponent" ,
153
210
constructor_arg_name = "exponent" ),
154
211
cirq .google .DeserializingArg (serialized_name = "global_shift" ,
155
212
constructor_arg_name = "global_shift" ),
156
213
cirq .google .DeserializingArg (serialized_name = "exponent_scalar" ,
157
- constructor_arg_name = "exponent_scalar" )
214
+ constructor_arg_name = "exponent_scalar" ),
215
+ cirq .google .DeserializingArg (serialized_name = "control_qubits" ,
216
+ constructor_arg_name = "control_qubits" ),
217
+ cirq .google .DeserializingArg (serialized_name = "control_values" ,
218
+ constructor_arg_name = "control_values" )
158
219
]
159
220
return cirq .google .GateOpDeserializer (serialized_gate_id = serialized_id ,
160
221
gate_constructor = _scalar_combiner ,
@@ -181,6 +242,13 @@ def _fsim_gate_serializer():
181
242
serialized_name = "phi_scalar" ,
182
243
serialized_type = float ,
183
244
op_getter = lambda x : _scalar_extractor (x .gate .phi )),
245
+ cirq .google .SerializingArg (serialized_name = "control_qubits" ,
246
+ serialized_type = str ,
247
+ op_getter = lambda x : _serialize_controls (x )),
248
+ cirq .google .SerializingArg (
249
+ serialized_name = "control_values" ,
250
+ serialized_type = str ,
251
+ op_getter = lambda x : _serialize_control_vals (x ))
184
252
]
185
253
return cirq .google .GateOpSerializer (gate_type = cirq .FSimGate ,
186
254
serialized_gate_id = "FSIM" ,
@@ -191,12 +259,15 @@ def _fsim_gate_serializer():
191
259
def _fsim_gate_deserializer ():
192
260
"""Make standard deserializer for fsim gate."""
193
261
194
- def _scalar_combiner (theta , theta_scalar , phi , phi_scalar ):
262
+ def _scalar_combiner (theta , theta_scalar , phi , phi_scalar , control_qubits ,
263
+ control_values ):
195
264
"""This is a workaround to support symbol scalar multiplication.
196
265
See `_eigen_gate_deserializer` for details.
197
266
"""
198
- return cirq .FSimGate (theta = _round (theta ) * _round (theta_scalar ),
199
- phi = _round (phi ) * _round (phi_scalar ))
267
+ return _optional_control_promote (
268
+ cirq .FSimGate (theta = _round (theta ) * _round (theta_scalar ),
269
+ phi = _round (phi ) * _round (phi_scalar )), control_qubits ,
270
+ control_values )
200
271
201
272
args = [
202
273
cirq .google .DeserializingArg (serialized_name = "theta" ,
@@ -207,6 +278,10 @@ def _scalar_combiner(theta, theta_scalar, phi, phi_scalar):
207
278
constructor_arg_name = "theta_scalar" ),
208
279
cirq .google .DeserializingArg (serialized_name = "phi_scalar" ,
209
280
constructor_arg_name = "phi_scalar" ),
281
+ cirq .google .DeserializingArg (serialized_name = "control_qubits" ,
282
+ constructor_arg_name = "control_qubits" ),
283
+ cirq .google .DeserializingArg (serialized_name = "control_values" ,
284
+ constructor_arg_name = "control_values" )
210
285
]
211
286
return cirq .google .GateOpDeserializer (serialized_gate_id = "FSIM" ,
212
287
gate_constructor = _scalar_combiner ,
@@ -228,7 +303,14 @@ def _identity_check(x):
228
303
args = [
229
304
cirq .google .SerializingArg (serialized_name = "unused" ,
230
305
serialized_type = bool ,
231
- op_getter = _identity_check )
306
+ op_getter = _identity_check ),
307
+ cirq .google .SerializingArg (serialized_name = "control_qubits" ,
308
+ serialized_type = str ,
309
+ op_getter = lambda x : _serialize_controls (x )),
310
+ cirq .google .SerializingArg (
311
+ serialized_name = "control_values" ,
312
+ serialized_type = str ,
313
+ op_getter = lambda x : _serialize_control_vals (x ))
232
314
]
233
315
return cirq .google .GateOpSerializer (gate_type = cirq .IdentityGate ,
234
316
serialized_gate_id = "I" ,
@@ -240,11 +322,15 @@ def _identity_gate_deserializer():
240
322
"""Make a standard deserializer for the single qubit identity."""
241
323
args = [
242
324
cirq .google .DeserializingArg (serialized_name = "unused" ,
243
- constructor_arg_name = "unused" )
325
+ constructor_arg_name = "unused" ),
326
+ cirq .google .DeserializingArg (serialized_name = "control_qubits" ,
327
+ constructor_arg_name = "control_qubits" ),
328
+ cirq .google .DeserializingArg (serialized_name = "control_values" ,
329
+ constructor_arg_name = "control_values" )
244
330
]
245
331
246
- def _cirq_i_workaround (unused ):
247
- return cirq .I
332
+ def _cirq_i_workaround (unused , control_qubits , control_values ):
333
+ return _optional_control_promote ( cirq .I , control_qubits , control_values )
248
334
249
335
return cirq .google .GateOpDeserializer (serialized_gate_id = "I" ,
250
336
gate_constructor = _cirq_i_workaround ,
@@ -274,7 +360,14 @@ def _phased_eigen_gate_serializer(gate_type, serialized_id):
274
360
cirq .google .SerializingArg (
275
361
serialized_name = "global_shift" ,
276
362
serialized_type = float ,
277
- op_getter = lambda x : float (x .gate .global_shift ))
363
+ op_getter = lambda x : float (x .gate .global_shift )),
364
+ cirq .google .SerializingArg (serialized_name = "control_qubits" ,
365
+ serialized_type = str ,
366
+ op_getter = lambda x : _serialize_controls (x )),
367
+ cirq .google .SerializingArg (
368
+ serialized_name = "control_values" ,
369
+ serialized_type = str ,
370
+ op_getter = lambda x : _serialize_control_vals (x ))
278
371
]
279
372
return cirq .google .GateOpSerializer (gate_type = gate_type ,
280
373
serialized_gate_id = serialized_id ,
@@ -286,7 +379,8 @@ def _phased_eigen_gate_deserializer(gate_type, serialized_id):
286
379
"""Make a standard deserializer for phased eigen gates."""
287
380
288
381
def _scalar_combiner (exponent , global_shift , exponent_scalar ,
289
- phase_exponent , phase_exponent_scalar ):
382
+ phase_exponent , phase_exponent_scalar , control_qubits ,
383
+ control_values ):
290
384
"""This is a workaround to support symbol scalar multiplication.
291
385
In the future we should likely get rid of this in favor of proper
292
386
expression parsing once cirq supports it. See cirq.op_serializer
@@ -302,10 +396,14 @@ def _scalar_combiner(exponent, global_shift, exponent_scalar,
302
396
if global_shift != 0 :
303
397
# needed in case this specific phasedeigengate doesn't
304
398
# have a global_phase in constructor.
305
- return gate_type (exponent = exponent ,
306
- global_shift = _round (global_shift ),
307
- phase_exponent = phase_exponent )
308
- return gate_type (exponent = exponent , phase_exponent = phase_exponent )
399
+ return _optional_control_promote (
400
+ gate_type (exponent = exponent ,
401
+ global_shift = _round (global_shift ),
402
+ phase_exponent = phase_exponent ), control_qubits ,
403
+ control_values )
404
+ return _optional_control_promote (
405
+ gate_type (exponent = exponent , phase_exponent = phase_exponent ),
406
+ control_qubits , control_values )
309
407
310
408
args = [
311
409
cirq .google .DeserializingArg (serialized_name = "phase_exponent" ,
@@ -319,6 +417,10 @@ def _scalar_combiner(exponent, global_shift, exponent_scalar,
319
417
constructor_arg_name = "exponent_scalar" ),
320
418
cirq .google .DeserializingArg (serialized_name = "global_shift" ,
321
419
constructor_arg_name = "global_shift" ),
420
+ cirq .google .DeserializingArg (serialized_name = "control_qubits" ,
421
+ constructor_arg_name = "control_qubits" ),
422
+ cirq .google .DeserializingArg (serialized_name = "control_values" ,
423
+ constructor_arg_name = "control_values" )
322
424
]
323
425
return cirq .google .GateOpDeserializer (serialized_gate_id = serialized_id ,
324
426
gate_constructor = _scalar_combiner ,
@@ -434,6 +536,21 @@ def serialize_circuit(circuit_inp):
434
536
old_moment .operations ))
435
537
circuit [moment_ind ] = new_moment
436
538
539
+ # Demote cirq.controlled_operations (controlled gates) to their sub_gate
540
+ # types with _tfq_control_qubits and _tfq_control_values fields so that
541
+ # the gates can still get picked up by the serializer. There would be no way
542
+ # to discern controlledgates from one another otherwise. This
543
+ # "momentary demotion" occurs with the help of the DelayedAssignmentGate.
544
+ for i , moment in enumerate (circuit ):
545
+ for op in moment :
546
+ if isinstance (op ,
547
+ cirq .ops .controlled_operation .ControlledOperation ):
548
+ tfq_compatible = op .sub_operation
549
+ tfq_compatible ._tfq_control_qubits = op .controls
550
+ tfq_compatible ._tfq_control_values = op .control_values
551
+ dropped_moment = moment .without_operations_touching (op .qubits )
552
+ circuit [i ] = dropped_moment .with_operation (tfq_compatible )
553
+
437
554
return SERIALIZER .serialize (circuit )
438
555
439
556
0 commit comments