Skip to content

Commit e4b00c7

Browse files
zhangskzJasonLunn
andauthored
Add support for extensions in CRuby, JRuby, and FFI Ruby (#14703) (#14756)
Follow up to #14594, which added support for custom options, this PR implements extensions support, which should fully resolve #1198. Closes #14703 COPYBARA_INTEGRATE_REVIEW=#14703 from protocolbuffers:add-support-for-extensions-in-ruby 601aca4 PiperOrigin-RevId: 582460674 Co-authored-by: Jason Lunn <[email protected]>
1 parent 2495d4f commit e4b00c7

File tree

13 files changed

+257
-55
lines changed

13 files changed

+257
-55
lines changed

ruby/ext/google/protobuf_c/defs.c

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,20 +144,26 @@ VALUE DescriptorPool_add_serialized_file(VALUE _self,
144144
* call-seq:
145145
* DescriptorPool.lookup(name) => descriptor
146146
*
147-
* Finds a Descriptor or EnumDescriptor by name and returns it, or nil if none
148-
* exists with the given name.
147+
* Finds a Descriptor, EnumDescriptor or FieldDescriptor by name and returns it,
148+
* or nil if none exists with the given name.
149149
*/
150150
static VALUE DescriptorPool_lookup(VALUE _self, VALUE name) {
151151
DescriptorPool* self = ruby_to_DescriptorPool(_self);
152152
const char* name_str = get_str(name);
153153
const upb_MessageDef* msgdef;
154154
const upb_EnumDef* enumdef;
155+
const upb_FieldDef* fielddef;
155156

156157
msgdef = upb_DefPool_FindMessageByName(self->symtab, name_str);
157158
if (msgdef) {
158159
return get_msgdef_obj(_self, msgdef);
159160
}
160161

162+
fielddef = upb_DefPool_FindExtensionByName(self->symtab, name_str);
163+
if (fielddef) {
164+
return get_fielddef_obj(_self, fielddef);
165+
}
166+
161167
enumdef = upb_DefPool_FindEnumByName(self->symtab, name_str);
162168
if (enumdef) {
163169
return get_enumdef_obj(_self, enumdef);

ruby/ext/google/protobuf_c/message.c

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -977,9 +977,12 @@ VALUE Message_decode_bytes(int size, const char* bytes, int options,
977977
VALUE msg_rb = initialize_rb_class_with_no_args(klass);
978978
Message* msg = ruby_to_Message(msg_rb);
979979

980+
const upb_FileDef* file = upb_MessageDef_File(msg->msgdef);
981+
const upb_ExtensionRegistry* extreg =
982+
upb_DefPool_ExtensionRegistry(upb_FileDef_Pool(file));
980983
upb_DecodeStatus status = upb_Decode(bytes, size, (upb_Message*)msg->msg,
981984
upb_MessageDef_MiniTable(msg->msgdef),
982-
NULL, options, Arena_get(msg->arena));
985+
extreg, options, Arena_get(msg->arena));
983986
if (status != kUpb_DecodeStatus_Ok) {
984987
rb_raise(cParseError, "Error occurred during parsing");
985988
}
@@ -1303,9 +1306,12 @@ upb_Message* Message_deep_copy(const upb_Message* msg, const upb_MessageDef* m,
13031306
upb_Message* new_msg = upb_Message_New(layout, arena);
13041307
char* data;
13051308

1309+
const upb_FileDef* file = upb_MessageDef_File(m);
1310+
const upb_ExtensionRegistry* extreg =
1311+
upb_DefPool_ExtensionRegistry(upb_FileDef_Pool(file));
13061312
if (upb_Encode(msg, layout, 0, tmp_arena, &data, &size) !=
13071313
kUpb_EncodeStatus_Ok ||
1308-
upb_Decode(data, size, new_msg, layout, NULL, 0, arena) !=
1314+
upb_Decode(data, size, new_msg, layout, extreg, 0, arena) !=
13091315
kUpb_DecodeStatus_Ok) {
13101316
upb_Arena_Free(tmp_arena);
13111317
rb_raise(cParseError, "Error occurred copying proto");

ruby/lib/google/protobuf/ffi/descriptor_pool.rb

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@ module Google
99
module Protobuf
1010
class FFI
1111
# DefPool
12-
attach_function :add_serialized_file, :upb_DefPool_AddFile, [:DefPool, :FileDescriptorProto, Status.by_ref], :FileDef
13-
attach_function :free_descriptor_pool, :upb_DefPool_Free, [:DefPool], :void
14-
attach_function :create_descriptor_pool,:upb_DefPool_New, [], :DefPool
15-
attach_function :lookup_enum, :upb_DefPool_FindEnumByName, [:DefPool, :string], EnumDescriptor
16-
attach_function :lookup_msg, :upb_DefPool_FindMessageByName, [:DefPool, :string], Descriptor
17-
# FileDescriptorProto
18-
attach_function :parse, :FileDescriptorProto_parse, [:binary_string, :size_t, Internal::Arena], :FileDescriptorProto
12+
attach_function :add_serialized_file, :upb_DefPool_AddFile, [:DefPool, :FileDescriptorProto, Status.by_ref], :FileDef
13+
attach_function :free_descriptor_pool, :upb_DefPool_Free, [:DefPool], :void
14+
attach_function :create_descriptor_pool,:upb_DefPool_New, [], :DefPool
15+
attach_function :get_extension_registry,:upb_DefPool_ExtensionRegistry, [:DefPool], :ExtensionRegistry
16+
attach_function :lookup_enum, :upb_DefPool_FindEnumByName, [:DefPool, :string], EnumDescriptor
17+
attach_function :lookup_extension, :upb_DefPool_FindExtensionByName,[:DefPool, :string], FieldDescriptor
18+
attach_function :lookup_msg, :upb_DefPool_FindMessageByName, [:DefPool, :string], Descriptor
19+
20+
# FileDescriptorProto
21+
attach_function :parse, :FileDescriptorProto_parse, [:binary_string, :size_t, Internal::Arena], :FileDescriptorProto
1922
end
2023
class DescriptorPool
2124
attr :descriptor_pool
@@ -50,7 +53,8 @@ def add_serialized_file(file_contents)
5053

5154
def lookup name
5255
Google::Protobuf::FFI.lookup_msg(@descriptor_pool, name) ||
53-
Google::Protobuf::FFI.lookup_enum(@descriptor_pool, name)
56+
Google::Protobuf::FFI.lookup_enum(@descriptor_pool, name) ||
57+
Google::Protobuf::FFI.lookup_extension(@descriptor_pool, name)
5458
end
5559

5660
def self.generated_pool

ruby/lib/google/protobuf/ffi/message.rb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,15 @@ def self.decode(data, options = {})
170170

171171
message = new
172172
mini_table_ptr = Google::Protobuf::FFI.get_mini_table(message.class.descriptor)
173-
status = Google::Protobuf::FFI.decode_message(data, data.bytesize, message.instance_variable_get(:@msg), mini_table_ptr, nil, decoding_options, message.instance_variable_get(:@arena))
173+
status = Google::Protobuf::FFI.decode_message(
174+
data,
175+
data.bytesize,
176+
message.instance_variable_get(:@msg),
177+
mini_table_ptr,
178+
Google::Protobuf::FFI.get_extension_registry(message.class.descriptor.send(:pool).descriptor_pool),
179+
decoding_options,
180+
message.instance_variable_get(:@arena)
181+
)
174182
raise ParseError.new "Error occurred during parsing" unless status == :Ok
175183
message
176184
end

ruby/src/main/java/com/google/protobuf/jruby/RubyDescriptorPool.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@
3636
import com.google.protobuf.Descriptors.Descriptor;
3737
import com.google.protobuf.Descriptors.DescriptorValidationException;
3838
import com.google.protobuf.Descriptors.EnumDescriptor;
39+
import com.google.protobuf.Descriptors.FieldDescriptor;
3940
import com.google.protobuf.Descriptors.FileDescriptor;
41+
import com.google.protobuf.ExtensionRegistry;
4042
import com.google.protobuf.InvalidProtocolBufferException;
4143
import java.util.ArrayList;
4244
import java.util.HashMap;
@@ -70,6 +72,7 @@ public IRubyObject allocate(Ruby runtime, RubyClass klazz) {
7072
cDescriptorPool.newInstance(runtime.getCurrentContext(), Block.NULL_BLOCK);
7173
cDescriptor = (RubyClass) runtime.getClassFromPath("Google::Protobuf::Descriptor");
7274
cEnumDescriptor = (RubyClass) runtime.getClassFromPath("Google::Protobuf::EnumDescriptor");
75+
cFieldDescriptor = (RubyClass) runtime.getClassFromPath("Google::Protobuf::FieldDescriptor");
7376
}
7477

7578
public RubyDescriptorPool(Ruby runtime, RubyClass klazz) {
@@ -92,7 +95,7 @@ public IRubyObject build(ThreadContext context, Block block) {
9295
* call-seq:
9396
* DescriptorPool.lookup(name) => descriptor
9497
*
95-
* Finds a Descriptor or EnumDescriptor by name and returns it, or nil if none
98+
* Finds a Descriptor, EnumDescriptor or FieldDescriptor by name and returns it, or nil if none
9699
* exists with the given name.
97100
*
98101
* This currently lazy loads the ruby descriptor objects as they are requested.
@@ -121,7 +124,8 @@ public static IRubyObject generatedPool(ThreadContext context, IRubyObject recv)
121124
public IRubyObject add_serialized_file(ThreadContext context, IRubyObject data) {
122125
byte[] bin = data.convertToString().getBytes();
123126
try {
124-
FileDescriptorProto.Builder builder = FileDescriptorProto.newBuilder().mergeFrom(bin);
127+
FileDescriptorProto.Builder builder =
128+
FileDescriptorProto.newBuilder().mergeFrom(bin, registry);
125129
registerFileDescriptor(context, builder);
126130
} catch (InvalidProtocolBufferException e) {
127131
throw RaiseException.from(
@@ -150,6 +154,8 @@ protected void registerFileDescriptor(
150154
for (EnumDescriptor ed : fd.getEnumTypes()) registerEnumDescriptor(context, ed, packageName);
151155
for (Descriptor message : fd.getMessageTypes())
152156
registerDescriptor(context, message, packageName);
157+
for (FieldDescriptor fieldDescriptor : fd.getExtensions())
158+
registerExtension(context, fieldDescriptor, packageName);
153159

154160
// Mark this as a loaded file
155161
fileDescriptors.add(fd);
@@ -170,6 +176,24 @@ private void registerDescriptor(ThreadContext context, Descriptor descriptor, St
170176
registerEnumDescriptor(context, ed, fullPath);
171177
for (Descriptor message : descriptor.getNestedTypes())
172178
registerDescriptor(context, message, fullPath);
179+
for (FieldDescriptor fieldDescriptor : descriptor.getExtensions())
180+
registerExtension(context, fieldDescriptor, fullPath);
181+
}
182+
183+
private void registerExtension(
184+
ThreadContext context, FieldDescriptor descriptor, String parentPath) {
185+
if (descriptor.getJavaType() == FieldDescriptor.JavaType.MESSAGE) {
186+
registry.add(descriptor, descriptor.toProto());
187+
} else {
188+
registry.add(descriptor);
189+
}
190+
RubyString name = context.runtime.newString(parentPath + descriptor.getName());
191+
RubyFieldDescriptor des =
192+
(RubyFieldDescriptor) cFieldDescriptor.newInstance(context, Block.NULL_BLOCK);
193+
des.setName(name);
194+
des.setDescriptor(context, descriptor, this);
195+
// For MessageSet extensions, there is the possibility of a name conflict. Prefer the Message.
196+
symtab.putIfAbsent(name, des);
173197
}
174198

175199
private void registerEnumDescriptor(
@@ -188,8 +212,10 @@ private FileDescriptor[] existingFileDescriptors() {
188212

189213
private static RubyClass cDescriptor;
190214
private static RubyClass cEnumDescriptor;
215+
private static RubyClass cFieldDescriptor;
191216
private static RubyDescriptorPool descriptorPool;
192217

193218
private List<FileDescriptor> fileDescriptors;
194219
private Map<IRubyObject, IRubyObject> symtab;
220+
protected static final ExtensionRegistry registry = ExtensionRegistry.newInstance();
195221
}

ruby/src/main/java/com/google/protobuf/jruby/RubyFieldDescriptor.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ public IRubyObject getName(ThreadContext context) {
103103
return this.name;
104104
}
105105

106+
protected void setName(IRubyObject name) {
107+
this.name = name;
108+
}
109+
106110
/*
107111
* call-seq:
108112
* FieldDescriptor.subtype => message_or_enum_descriptor
@@ -229,7 +233,7 @@ public IRubyObject has(ThreadContext context, IRubyObject message) {
229233
*/
230234
@JRubyMethod(name = "set")
231235
public IRubyObject setValue(ThreadContext context, IRubyObject message, IRubyObject value) {
232-
((RubyMessage) message).setField(context, descriptor, value);
236+
((RubyMessage) message).setField(context, this, value);
233237
return context.nil;
234238
}
235239

@@ -263,6 +267,10 @@ protected void setDescriptor(
263267
this.pool = pool;
264268
}
265269

270+
protected FieldDescriptor getDescriptor() {
271+
return descriptor;
272+
}
273+
266274
private void calculateLabel(ThreadContext context) {
267275
if (descriptor.isRepeated()) {
268276
this.label = context.runtime.newSymbol("repeated");

ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ public static IRubyObject decode(ThreadContext context, IRubyObject recv, IRubyO
634634
public static IRubyObject decodeBytes(
635635
ThreadContext context, RubyMessage ret, CodedInputStream input, boolean freeze) {
636636
try {
637-
ret.builder.mergeFrom(input);
637+
ret.builder.mergeFrom(input, RubyDescriptorPool.registry);
638638
} catch (Exception e) {
639639
throw RaiseException.from(
640640
context.runtime,
@@ -965,6 +965,12 @@ protected IRubyObject setField(
965965
return setFieldInternal(context, fieldDescriptor, value);
966966
}
967967

968+
protected IRubyObject setField(
969+
ThreadContext context, RubyFieldDescriptor fieldDescriptor, IRubyObject value) {
970+
validateMessageType(context, fieldDescriptor.getDescriptor(), "set");
971+
return setFieldInternal(context, fieldDescriptor.getDescriptor(), fieldDescriptor, value);
972+
}
973+
968974
private RubyRepeatedField getRepeatedField(
969975
ThreadContext context, FieldDescriptor fieldDescriptor) {
970976
if (fields.containsKey(fieldDescriptor)) {
@@ -1275,6 +1281,14 @@ private IRubyObject getFieldInternal(
12751281

12761282
private IRubyObject setFieldInternal(
12771283
ThreadContext context, FieldDescriptor fieldDescriptor, IRubyObject value) {
1284+
return setFieldInternal(context, fieldDescriptor, null, value);
1285+
}
1286+
1287+
private IRubyObject setFieldInternal(
1288+
ThreadContext context,
1289+
FieldDescriptor fieldDescriptor,
1290+
RubyFieldDescriptor rubyFieldDescriptor,
1291+
IRubyObject value) {
12781292
testFrozen("can't modify frozen " + getMetaClass());
12791293

12801294
if (fieldDescriptor.isMapField()) {
@@ -1299,8 +1313,12 @@ private IRubyObject setFieldInternal(
12991313
// Determine the typeclass, if any
13001314
IRubyObject typeClass = context.runtime.getObject();
13011315
if (fieldType == FieldDescriptor.Type.MESSAGE) {
1302-
typeClass =
1303-
((RubyDescriptor) getDescriptorForField(context, fieldDescriptor)).msgclass(context);
1316+
if (rubyFieldDescriptor != null) {
1317+
typeClass = ((RubyDescriptor) rubyFieldDescriptor.getSubtype(context)).msgclass(context);
1318+
} else {
1319+
typeClass =
1320+
((RubyDescriptor) getDescriptorForField(context, fieldDescriptor)).msgclass(context);
1321+
}
13041322
if (value.isNil()) {
13051323
addValue = false;
13061324
}

ruby/tests/basic.rb

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,19 @@ def test_oneof_descriptor_options
729729
oneof_descriptor = descriptor.lookup_oneof("test_deprecated_message_oneof")
730730

731731
assert_instance_of Google::Protobuf::OneofOptions, oneof_descriptor.options
732+
test_top_level_option = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test.test_top_level_option'
733+
assert_instance_of Google::Protobuf::FieldDescriptor, test_top_level_option
734+
assert_equal "Custom option value", test_top_level_option.get(oneof_descriptor.options)
735+
end
736+
737+
def test_nested_extension
738+
descriptor = TestDeprecatedMessage.descriptor
739+
oneof_descriptor = descriptor.lookup_oneof("test_deprecated_message_oneof")
740+
741+
assert_instance_of Google::Protobuf::OneofOptions, oneof_descriptor.options
742+
test_nested_option = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test.TestDeprecatedMessage.test_nested_option'
743+
assert_instance_of Google::Protobuf::FieldDescriptor, test_nested_option
744+
assert_equal "Another custom option value", test_nested_option.get(oneof_descriptor.options)
732745
end
733746

734747
def test_options_deep_freeze
@@ -739,6 +752,25 @@ def test_options_deep_freeze
739752
Google::Protobuf::UninterpretedOption.new
740753
end
741754
end
755+
756+
def test_message_deep_freeze
757+
message = TestDeprecatedMessage.new
758+
omit(":internal_deep_freeze only exists under FFI") unless message.respond_to? :internal_deep_freeze, true
759+
nested_message_2 = TestMessage2.new
760+
761+
message.map_string_msg["message"] = TestMessage2.new
762+
message.repeated_msg.push(TestMessage2.new)
763+
764+
message.send(:internal_deep_freeze)
765+
766+
assert_raise FrozenError do
767+
message.map_string_msg["message"].foo = "bar"
768+
end
769+
770+
assert_raise FrozenError do
771+
message.repeated_msg[0].foo = "bar"
772+
end
773+
end
742774
end
743775

744776
def test_oneof_fields_respond_to? # regression test for issue 9202

ruby/tests/basic_proto2.rb

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,5 +269,57 @@ def test_oneof_fields_respond_to? # regression test for issue 9202
269269
assert msg.respond_to? :has_d?
270270
refute msg.has_d?
271271
end
272+
273+
def test_extension
274+
message = TestExtensions.new
275+
extension = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.optional_int32_extension'
276+
assert_instance_of Google::Protobuf::FieldDescriptor, extension
277+
assert_equal 0, extension.get(message)
278+
extension.set message, 42
279+
assert_equal 42, extension.get(message)
280+
end
281+
282+
def test_nested_extension
283+
message = TestExtensions.new
284+
extension = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestNestedExtension.test'
285+
assert_instance_of Google::Protobuf::FieldDescriptor, extension
286+
assert_equal 'test', extension.get(message)
287+
extension.set message, 'another test'
288+
assert_equal 'another test', extension.get(message)
289+
end
290+
291+
def test_message_set_extension_json_roundtrip
292+
omit "Java Protobuf JsonFormat does not handle Proto2 extensions" if defined? JRUBY_VERSION and :NATIVE == Google::Protobuf::IMPLEMENTATION
293+
message = TestMessageSet.new
294+
ext1 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestMessageSetExtension1.message_set_extension'
295+
assert_instance_of Google::Protobuf::FieldDescriptor, ext1
296+
ext2 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestMessageSetExtension2.message_set_extension'
297+
assert_instance_of Google::Protobuf::FieldDescriptor, ext2
298+
ext3 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.message_set_extension3'
299+
assert_instance_of Google::Protobuf::FieldDescriptor, ext3
300+
ext1.set(message, ext1.subtype.msgclass.new(i: 42))
301+
ext2.set(message, ext2.subtype.msgclass.new(str: 'foo'))
302+
ext3.set(message, ext3.subtype.msgclass.new(text: 'bar'))
303+
message_text = message.to_json
304+
parsed_message = TestMessageSet.decode_json message_text
305+
assert_equal message, parsed_message
306+
end
307+
308+
309+
def test_message_set_extension_roundtrip
310+
message = TestMessageSet.new
311+
ext1 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestMessageSetExtension1.message_set_extension'
312+
assert_instance_of Google::Protobuf::FieldDescriptor, ext1
313+
ext2 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestMessageSetExtension2.message_set_extension'
314+
assert_instance_of Google::Protobuf::FieldDescriptor, ext2
315+
ext3 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.message_set_extension3'
316+
assert_instance_of Google::Protobuf::FieldDescriptor, ext3
317+
ext1.set(message, ext1.subtype.msgclass.new(i: 42))
318+
ext2.set(message, ext2.subtype.msgclass.new(str: 'foo'))
319+
ext3.set(message, ext3.subtype.msgclass.new(text: 'bar'))
320+
encoded_message = TestMessageSet.encode message
321+
decoded_message = TestMessageSet.decode encoded_message
322+
assert_equal message, decoded_message
323+
end
272324
end
273325
end

0 commit comments

Comments
 (0)