Skip to content

Commit ef22294

Browse files
Serialize multiqubit (#447)
* Add multi qubit gate serialization support. * small format. * format. * yet more format fixes. * A Feedback.
1 parent 7b8728d commit ef22294

File tree

3 files changed

+221
-24
lines changed

3 files changed

+221
-24
lines changed

tensorflow_quantum/core/ops/tfq_ps_decompose_op.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,11 @@ class TfqPsDecomposeOp : public tensorflow::OpKernel {
176176
new_op_map["exponent_scalar"].mutable_arg_value()->set_float_value(
177177
cur_exponent_scalar * -0.5);
178178
new_op_map["exponent"].set_symbol(symbol);
179+
// Copy over control metadata.
180+
new_op_map["control_qubits"].mutable_arg_value()->set_string_value(
181+
cur_op_map["control_qubits"].arg_value().string_value());
182+
new_op_map["control_values"].mutable_arg_value()->set_string_value(
183+
cur_op_map["control_values"].arg_value().string_value());
179184
// Step 4. add qubits.
180185
*new_op.mutable_qubits() = {cur_op_qubits.begin(), cur_op_qubits.end()};
181186
return new_op;
@@ -215,6 +220,11 @@ class TfqPsDecomposeOp : public tensorflow::OpKernel {
215220
}
216221
// Step 4. add qubits.
217222
*new_op.mutable_qubits() = {cur_op_qubits.begin(), cur_op_qubits.end()};
223+
// Copy over control metadata.
224+
new_op_map["control_qubits"].mutable_arg_value()->set_string_value(
225+
cur_op_map["control_qubits"].arg_value().string_value());
226+
new_op_map["control_values"].mutable_arg_value()->set_string_value(
227+
cur_op_map["control_values"].arg_value().string_value());
218228
return new_op;
219229
}
220230

@@ -251,6 +261,11 @@ class TfqPsDecomposeOp : public tensorflow::OpKernel {
251261
}
252262
*new_op.mutable_qubits() = {cur_op_qubits.begin() + use_target,
253263
cur_op_qubits.end() - !use_target};
264+
// Copy over control metadata.
265+
new_op_map["control_qubits"].mutable_arg_value()->set_string_value(
266+
cur_op_map["control_qubits"].arg_value().string_value());
267+
new_op_map["control_values"].mutable_arg_value()->set_string_value(
268+
cur_op_map["control_values"].arg_value().string_value());
254269
return new_op;
255270
}
256271

@@ -290,6 +305,11 @@ class TfqPsDecomposeOp : public tensorflow::OpKernel {
290305
}
291306
// Step 4. add qubits.
292307
*new_op.mutable_qubits() = {cur_op_qubits.begin(), cur_op_qubits.end()};
308+
// Copy over control metadata.
309+
new_op_map["control_qubits"].mutable_arg_value()->set_string_value(
310+
cur_op_map["control_qubits"].arg_value().string_value());
311+
new_op_map["control_values"].mutable_arg_value()->set_string_value(
312+
cur_op_map["control_values"].arg_value().string_value());
293313
return new_op;
294314
}
295315
};

tensorflow_quantum/core/serialize/serializer.py

Lines changed: 137 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,51 @@ def _symbol_extractor(x):
109109
"information.")
110110

111111

112+
def _serialize_controls(gate):
113+
"""Helper to serialize control qubits if applicable."""
114+
if hasattr(gate, '_tfq_control_qubits'):
115+
return ','.join(
116+
v2.qubit_to_proto_id(q) for q in gate._tfq_control_qubits)
117+
return ''
118+
119+
120+
def _serialize_control_vals(gate):
121+
"""Helper to serialize control values if applicable.."""
122+
if hasattr(gate, '_tfq_control_values'):
123+
return ','.join(str(v[0]) for v in gate._tfq_control_values)
124+
return ''
125+
126+
127+
class DelayedAssignmentGate(cirq.Gate):
128+
"""Class to do control qubit assignment before sub_gate qubit assignment."""
129+
130+
def __init__(self, gate_callable, control_qubits, control_values):
131+
self._gate_callable = gate_callable
132+
self._control_qubits = control_qubits
133+
self._control_values = control_values
134+
135+
def _qid_shape_(self):
136+
raise ValueError("Called qid_shape on workaround class.")
137+
138+
# pylint: disable=invalid-name
139+
def on(self, *qubits):
140+
"""Returns gate_callable on qubits controlled by contol_qubits."""
141+
return self._gate_callable(*qubits).controlled_by(
142+
*self._control_qubits, control_values=self._control_values)
143+
144+
# pylint: enable=invalid-name
145+
146+
147+
def _optional_control_promote(gate, qubits_message, values_message):
148+
"""Optionally promote to controlled gate based on serialized control msg."""
149+
if qubits_message == '' and values_message == '':
150+
return gate
151+
qbs = [v2.qubit_from_proto_id(qb) for qb in qubits_message.split(',')]
152+
vals = [int(cv) for cv in values_message.split(',')]
153+
154+
return DelayedAssignmentGate(gate, qbs, vals)
155+
156+
112157
def _eigen_gate_serializer(gate_type, serialized_id):
113158
"""Make standard serializer for eigen gates."""
114159

@@ -124,7 +169,14 @@ def _eigen_gate_serializer(gate_type, serialized_id):
124169
cirq.google.SerializingArg(
125170
serialized_name="global_shift",
126171
serialized_type=float,
127-
op_getter=lambda x: float(x.gate._global_shift))
172+
op_getter=lambda x: float(x.gate._global_shift)),
173+
cirq.google.SerializingArg(serialized_name="control_qubits",
174+
serialized_type=str,
175+
op_getter=lambda x: _serialize_controls(x)),
176+
cirq.google.SerializingArg(
177+
serialized_name="control_values",
178+
serialized_type=str,
179+
op_getter=lambda x: _serialize_control_vals(x))
128180
]
129181
return cirq.google.GateOpSerializer(gate_type=gate_type,
130182
serialized_gate_id=serialized_id,
@@ -135,26 +187,35 @@ def _eigen_gate_serializer(gate_type, serialized_id):
135187
def _eigen_gate_deserializer(gate_type, serialized_id):
136188
"""Make standard deserializer for eigen gates."""
137189

138-
def _scalar_combiner(exponent, global_shift, exponent_scalar):
190+
def _scalar_combiner(exponent, global_shift, exponent_scalar,
191+
control_qubits, control_values):
139192
"""This is a workaround to support symbol scalar multiplication.
140193
In the future we should likely get rid of this in favor of proper
141194
expression parsing once cirq supports it. See cirq.op_serializer
142195
and cirq's program protobuf for details. This is needed for things
143196
like cirq.rx('alpha').
144197
"""
145198
if exponent_scalar == 1.0:
146-
return gate_type(exponent=_round(exponent),
147-
global_shift=_round(global_shift))
148-
return gate_type(exponent=_round(exponent) * _round(exponent_scalar),
149-
global_shift=_round(global_shift))
199+
return _optional_control_promote(
200+
gate_type(exponent=_round(exponent),
201+
global_shift=_round(global_shift)), control_qubits,
202+
control_values)
203+
return _optional_control_promote(
204+
gate_type(exponent=_round(exponent) * _round(exponent_scalar),
205+
global_shift=_round(global_shift)), control_qubits,
206+
control_values)
150207

151208
args = [
152209
cirq.google.DeserializingArg(serialized_name="exponent",
153210
constructor_arg_name="exponent"),
154211
cirq.google.DeserializingArg(serialized_name="global_shift",
155212
constructor_arg_name="global_shift"),
156213
cirq.google.DeserializingArg(serialized_name="exponent_scalar",
157-
constructor_arg_name="exponent_scalar")
214+
constructor_arg_name="exponent_scalar"),
215+
cirq.google.DeserializingArg(serialized_name="control_qubits",
216+
constructor_arg_name="control_qubits"),
217+
cirq.google.DeserializingArg(serialized_name="control_values",
218+
constructor_arg_name="control_values")
158219
]
159220
return cirq.google.GateOpDeserializer(serialized_gate_id=serialized_id,
160221
gate_constructor=_scalar_combiner,
@@ -181,6 +242,13 @@ def _fsim_gate_serializer():
181242
serialized_name="phi_scalar",
182243
serialized_type=float,
183244
op_getter=lambda x: _scalar_extractor(x.gate.phi)),
245+
cirq.google.SerializingArg(serialized_name="control_qubits",
246+
serialized_type=str,
247+
op_getter=lambda x: _serialize_controls(x)),
248+
cirq.google.SerializingArg(
249+
serialized_name="control_values",
250+
serialized_type=str,
251+
op_getter=lambda x: _serialize_control_vals(x))
184252
]
185253
return cirq.google.GateOpSerializer(gate_type=cirq.FSimGate,
186254
serialized_gate_id="FSIM",
@@ -191,12 +259,15 @@ def _fsim_gate_serializer():
191259
def _fsim_gate_deserializer():
192260
"""Make standard deserializer for fsim gate."""
193261

194-
def _scalar_combiner(theta, theta_scalar, phi, phi_scalar):
262+
def _scalar_combiner(theta, theta_scalar, phi, phi_scalar, control_qubits,
263+
control_values):
195264
"""This is a workaround to support symbol scalar multiplication.
196265
See `_eigen_gate_deserializer` for details.
197266
"""
198-
return cirq.FSimGate(theta=_round(theta) * _round(theta_scalar),
199-
phi=_round(phi) * _round(phi_scalar))
267+
return _optional_control_promote(
268+
cirq.FSimGate(theta=_round(theta) * _round(theta_scalar),
269+
phi=_round(phi) * _round(phi_scalar)), control_qubits,
270+
control_values)
200271

201272
args = [
202273
cirq.google.DeserializingArg(serialized_name="theta",
@@ -207,6 +278,10 @@ def _scalar_combiner(theta, theta_scalar, phi, phi_scalar):
207278
constructor_arg_name="theta_scalar"),
208279
cirq.google.DeserializingArg(serialized_name="phi_scalar",
209280
constructor_arg_name="phi_scalar"),
281+
cirq.google.DeserializingArg(serialized_name="control_qubits",
282+
constructor_arg_name="control_qubits"),
283+
cirq.google.DeserializingArg(serialized_name="control_values",
284+
constructor_arg_name="control_values")
210285
]
211286
return cirq.google.GateOpDeserializer(serialized_gate_id="FSIM",
212287
gate_constructor=_scalar_combiner,
@@ -228,7 +303,14 @@ def _identity_check(x):
228303
args = [
229304
cirq.google.SerializingArg(serialized_name="unused",
230305
serialized_type=bool,
231-
op_getter=_identity_check)
306+
op_getter=_identity_check),
307+
cirq.google.SerializingArg(serialized_name="control_qubits",
308+
serialized_type=str,
309+
op_getter=lambda x: _serialize_controls(x)),
310+
cirq.google.SerializingArg(
311+
serialized_name="control_values",
312+
serialized_type=str,
313+
op_getter=lambda x: _serialize_control_vals(x))
232314
]
233315
return cirq.google.GateOpSerializer(gate_type=cirq.IdentityGate,
234316
serialized_gate_id="I",
@@ -240,11 +322,15 @@ def _identity_gate_deserializer():
240322
"""Make a standard deserializer for the single qubit identity."""
241323
args = [
242324
cirq.google.DeserializingArg(serialized_name="unused",
243-
constructor_arg_name="unused")
325+
constructor_arg_name="unused"),
326+
cirq.google.DeserializingArg(serialized_name="control_qubits",
327+
constructor_arg_name="control_qubits"),
328+
cirq.google.DeserializingArg(serialized_name="control_values",
329+
constructor_arg_name="control_values")
244330
]
245331

246-
def _cirq_i_workaround(unused):
247-
return cirq.I
332+
def _cirq_i_workaround(unused, control_qubits, control_values):
333+
return _optional_control_promote(cirq.I, control_qubits, control_values)
248334

249335
return cirq.google.GateOpDeserializer(serialized_gate_id="I",
250336
gate_constructor=_cirq_i_workaround,
@@ -274,7 +360,14 @@ def _phased_eigen_gate_serializer(gate_type, serialized_id):
274360
cirq.google.SerializingArg(
275361
serialized_name="global_shift",
276362
serialized_type=float,
277-
op_getter=lambda x: float(x.gate.global_shift))
363+
op_getter=lambda x: float(x.gate.global_shift)),
364+
cirq.google.SerializingArg(serialized_name="control_qubits",
365+
serialized_type=str,
366+
op_getter=lambda x: _serialize_controls(x)),
367+
cirq.google.SerializingArg(
368+
serialized_name="control_values",
369+
serialized_type=str,
370+
op_getter=lambda x: _serialize_control_vals(x))
278371
]
279372
return cirq.google.GateOpSerializer(gate_type=gate_type,
280373
serialized_gate_id=serialized_id,
@@ -286,7 +379,8 @@ def _phased_eigen_gate_deserializer(gate_type, serialized_id):
286379
"""Make a standard deserializer for phased eigen gates."""
287380

288381
def _scalar_combiner(exponent, global_shift, exponent_scalar,
289-
phase_exponent, phase_exponent_scalar):
382+
phase_exponent, phase_exponent_scalar, control_qubits,
383+
control_values):
290384
"""This is a workaround to support symbol scalar multiplication.
291385
In the future we should likely get rid of this in favor of proper
292386
expression parsing once cirq supports it. See cirq.op_serializer
@@ -302,10 +396,14 @@ def _scalar_combiner(exponent, global_shift, exponent_scalar,
302396
if global_shift != 0:
303397
# needed in case this specific phasedeigengate doesn't
304398
# have a global_phase in constructor.
305-
return gate_type(exponent=exponent,
306-
global_shift=_round(global_shift),
307-
phase_exponent=phase_exponent)
308-
return gate_type(exponent=exponent, phase_exponent=phase_exponent)
399+
return _optional_control_promote(
400+
gate_type(exponent=exponent,
401+
global_shift=_round(global_shift),
402+
phase_exponent=phase_exponent), control_qubits,
403+
control_values)
404+
return _optional_control_promote(
405+
gate_type(exponent=exponent, phase_exponent=phase_exponent),
406+
control_qubits, control_values)
309407

310408
args = [
311409
cirq.google.DeserializingArg(serialized_name="phase_exponent",
@@ -319,6 +417,10 @@ def _scalar_combiner(exponent, global_shift, exponent_scalar,
319417
constructor_arg_name="exponent_scalar"),
320418
cirq.google.DeserializingArg(serialized_name="global_shift",
321419
constructor_arg_name="global_shift"),
420+
cirq.google.DeserializingArg(serialized_name="control_qubits",
421+
constructor_arg_name="control_qubits"),
422+
cirq.google.DeserializingArg(serialized_name="control_values",
423+
constructor_arg_name="control_values")
322424
]
323425
return cirq.google.GateOpDeserializer(serialized_gate_id=serialized_id,
324426
gate_constructor=_scalar_combiner,
@@ -434,6 +536,21 @@ def serialize_circuit(circuit_inp):
434536
old_moment.operations))
435537
circuit[moment_ind] = new_moment
436538

539+
# Demote cirq.controlled_operations (controlled gates) to their sub_gate
540+
# types with _tfq_control_qubits and _tfq_control_values fields so that
541+
# the gates can still get picked up by the serializer. There would be no way
542+
# to discern controlledgates from one another otherwise. This
543+
# "momentary demotion" occurs with the help of the DelayedAssignmentGate.
544+
for i, moment in enumerate(circuit):
545+
for op in moment:
546+
if isinstance(op,
547+
cirq.ops.controlled_operation.ControlledOperation):
548+
tfq_compatible = op.sub_operation
549+
tfq_compatible._tfq_control_qubits = op.controls
550+
tfq_compatible._tfq_control_values = op.control_values
551+
dropped_moment = moment.without_operations_touching(op.qubits)
552+
circuit[i] = dropped_moment.with_operation(tfq_compatible)
553+
437554
return SERIALIZER.serialize(circuit)
438555

439556

0 commit comments

Comments
 (0)