Skip to content

Commit 0123565

Browse files
Merge pull request #56 from gleanwork/cfreeman/fix-multipart-form-bug
fix: implement Speakeasy hook to fix multipart file field names
2 parents 44b18d6 + 643f273 commit 0123565

File tree

3 files changed

+299
-0
lines changed

3 files changed

+299
-0
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""Hook to fix multipart form file field names that incorrectly have '[]' suffix."""
2+
3+
from typing import Any, Dict, List, Tuple
4+
from .types import SDKInitHook
5+
from glean.api_client.httpclient import HttpClient
6+
from glean.api_client.utils import forms
7+
8+
9+
class MultipartFileFieldFixHook(SDKInitHook):
10+
"""
11+
Fixes multipart form serialization where file field names incorrectly have '[]' suffix.
12+
13+
Speakeasy sometimes generates code that adds '[]' to file field names in multipart forms,
14+
but this is incorrect. File fields should not have the array suffix, only regular form
15+
fields should use this convention.
16+
17+
This hook patches the serialize_multipart_form function to fix the issue at the source.
18+
"""
19+
20+
def sdk_init(self, base_url: str, client: HttpClient) -> Tuple[str, HttpClient]:
21+
"""Initialize the SDK and patch the multipart form serialization."""
22+
self._patch_multipart_serialization()
23+
return base_url, client
24+
25+
def _patch_multipart_serialization(self):
26+
"""Patch the serialize_multipart_form function to fix file field names."""
27+
# Store reference to original function
28+
original_serialize_multipart_form = forms.serialize_multipart_form
29+
30+
def fixed_serialize_multipart_form(
31+
media_type: str, request: Any
32+
) -> Tuple[str, Dict[str, Any], List[Tuple[str, Any]]]:
33+
"""Fixed version of serialize_multipart_form that doesn't add '[]' to file field names."""
34+
# Call the original function
35+
result_media_type, form_data, files_list = (
36+
original_serialize_multipart_form(media_type, request)
37+
)
38+
39+
# Fix file field names in the files list
40+
fixed_files = []
41+
for item in files_list:
42+
if isinstance(item, tuple) and len(item) >= 2:
43+
field_name = item[0]
44+
file_data = item[1]
45+
46+
# Remove '[]' suffix from file field names only
47+
# We can identify file fields by checking if the data looks like file content
48+
if field_name.endswith("[]") and self._is_file_field_data(
49+
file_data
50+
):
51+
fixed_field_name = field_name[:-2] # Remove '[]' suffix
52+
fixed_item = (fixed_field_name,) + item[1:]
53+
fixed_files.append(fixed_item)
54+
else:
55+
fixed_files.append(item)
56+
else:
57+
fixed_files.append(item)
58+
59+
return result_media_type, form_data, fixed_files
60+
61+
# Replace the original function with our fixed version
62+
forms.serialize_multipart_form = fixed_serialize_multipart_form
63+
64+
def _is_file_field_data(self, file_data: Any) -> bool:
65+
"""
66+
Determine if the data represents file field content.
67+
68+
File fields typically have tuple format: (filename, content) or (filename, content, content_type)
69+
where content is bytes, file-like object, or similar.
70+
"""
71+
if isinstance(file_data, tuple) and len(file_data) >= 2:
72+
# Check the structure: (filename, content, [optional content_type])
73+
filename = file_data[0]
74+
content = file_data[1]
75+
76+
# If filename is empty, this is likely JSON content, not a file
77+
if filename == "":
78+
return False
79+
80+
# File content is typically bytes, string, or file-like object
81+
# But exclude empty strings and None values
82+
if content is None or content == "":
83+
return False
84+
85+
return (
86+
isinstance(content, (bytes, str))
87+
or hasattr(content, "read") # File-like object
88+
or (
89+
hasattr(content, "__iter__") and not isinstance(content, str)
90+
) # Iterable but not string
91+
)
92+
return False

src/glean/api_client/_hooks/registration.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .types import Hooks
2+
from .multipart_fix_hook import MultipartFileFieldFixHook
23

34

45
# This file is only ever generated once on the first generation and then is free to be modified.
@@ -11,3 +12,6 @@ def init_hooks(hooks: Hooks):
1112
"""Add hooks by calling hooks.register{sdk_init/before_request/after_success/after_error}Hook
1213
with an instance of a hook that implements that specific Hook interface
1314
Hooks are registered per SDK instance, and are valid for the lifetime of the SDK instance"""
15+
16+
# Register hook to fix multipart file field names that incorrectly have '[]' suffix
17+
hooks.register_sdk_init_hook(MultipartFileFieldFixHook())

tests/test_multipart_fix_hook.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
"""Test for the multipart file field fix hook."""
2+
3+
from unittest.mock import Mock, patch
4+
5+
import pytest
6+
7+
from src.glean.api_client._hooks.multipart_fix_hook import MultipartFileFieldFixHook
8+
from src.glean.api_client.httpclient import HttpClient
9+
10+
11+
class TestMultipartFileFieldFixHook:
12+
"""Test cases for the MultipartFileFieldFixHook."""
13+
14+
def setup_method(self):
15+
"""Set up test fixtures."""
16+
self.hook = MultipartFileFieldFixHook()
17+
self.mock_client = Mock(spec=HttpClient)
18+
19+
def test_sdk_init_returns_unchanged_params(self):
20+
"""Test that SDK init returns the same base_url and client."""
21+
base_url = "https://api.example.com"
22+
23+
with patch.object(self.hook, "_patch_multipart_serialization"):
24+
result_url, result_client = self.hook.sdk_init(base_url, self.mock_client)
25+
26+
assert result_url == base_url
27+
assert result_client == self.mock_client
28+
29+
def test_sdk_init_calls_patch_function(self):
30+
"""Test that SDK init calls the patch function."""
31+
base_url = "https://api.example.com"
32+
33+
with patch.object(self.hook, "_patch_multipart_serialization") as mock_patch:
34+
self.hook.sdk_init(base_url, self.mock_client)
35+
mock_patch.assert_called_once()
36+
37+
def test_is_file_field_data_identifies_file_content(self):
38+
"""Test the file field data identification logic."""
39+
# Test file field formats
40+
assert self.hook._is_file_field_data(("test.txt", b"content"))
41+
assert self.hook._is_file_field_data(("test.txt", b"content", "text/plain"))
42+
assert self.hook._is_file_field_data(("test.txt", "string content"))
43+
44+
# Test with file-like object
45+
mock_file = Mock()
46+
mock_file.read = Mock()
47+
assert self.hook._is_file_field_data(("test.txt", mock_file))
48+
49+
# Test non-file field formats
50+
assert not self.hook._is_file_field_data("regular_value")
51+
assert not self.hook._is_file_field_data(123)
52+
assert not self.hook._is_file_field_data(("single_item",))
53+
assert not self.hook._is_file_field_data((None, None))
54+
55+
@patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
56+
def test_patch_multipart_serialization_replaces_function(self, mock_forms_module):
57+
"""Test that the patching replaces the serialize_multipart_form function."""
58+
# Mock the original function
59+
original_function = Mock()
60+
mock_forms_module.serialize_multipart_form = original_function
61+
62+
# Call the patch method
63+
self.hook._patch_multipart_serialization()
64+
65+
# Verify that the function was replaced
66+
assert mock_forms_module.serialize_multipart_form != original_function
67+
68+
@patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
69+
def test_patched_function_fixes_file_field_names(self, mock_forms_module):
70+
"""Test that the patched function correctly fixes file field names."""
71+
# Mock original function to return data with '[]' suffix
72+
original_function = Mock()
73+
original_function.return_value = (
74+
"multipart/form-data",
75+
{"regular_field": "value"},
76+
[
77+
("file[]", ("test.txt", b"file content", "text/plain")),
78+
("documents[]", ("doc.pdf", b"pdf content", "application/pdf")),
79+
("regular_array[]", "regular_value"), # This should not be changed
80+
],
81+
)
82+
mock_forms_module.serialize_multipart_form = original_function
83+
84+
# Apply the patch
85+
self.hook._patch_multipart_serialization()
86+
87+
# Get the patched function
88+
patched_function = mock_forms_module.serialize_multipart_form
89+
90+
# Call the patched function
91+
media_type, form_data, files_list = patched_function(
92+
"multipart/form-data", Mock()
93+
)
94+
95+
# Verify the results
96+
assert media_type == "multipart/form-data"
97+
assert form_data == {"regular_field": "value"}
98+
99+
# Check that file field names are fixed but regular fields are not
100+
expected_files = [
101+
("file", ("test.txt", b"file content", "text/plain")),
102+
("documents", ("doc.pdf", b"pdf content", "application/pdf")),
103+
("regular_array[]", "regular_value"), # Should remain unchanged
104+
]
105+
assert files_list == expected_files
106+
107+
@patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
108+
def test_patched_function_preserves_correct_names(self, mock_forms_module):
109+
"""Test that the patched function preserves already correct field names."""
110+
# Mock original function to return data without '[]' suffix
111+
original_function = Mock()
112+
original_function.return_value = (
113+
"multipart/form-data",
114+
{},
115+
[
116+
("file", ("test.txt", b"file content", "text/plain")),
117+
("document", ("doc.pdf", b"pdf content", "application/pdf")),
118+
],
119+
)
120+
mock_forms_module.serialize_multipart_form = original_function
121+
122+
# Apply the patch
123+
self.hook._patch_multipart_serialization()
124+
125+
# Get the patched function
126+
patched_function = mock_forms_module.serialize_multipart_form
127+
128+
# Call the patched function
129+
media_type, form_data, files_list = patched_function(
130+
"multipart/form-data", Mock()
131+
)
132+
133+
# Verify that nothing was changed
134+
expected_files = [
135+
("file", ("test.txt", b"file content", "text/plain")),
136+
("document", ("doc.pdf", b"pdf content", "application/pdf")),
137+
]
138+
assert files_list == expected_files
139+
140+
@patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
141+
def test_patched_function_handles_mixed_fields(self, mock_forms_module):
142+
"""Test handling of mixed file and non-file fields."""
143+
# Mock original function with mixed field types
144+
original_function = Mock()
145+
original_function.return_value = (
146+
"multipart/form-data",
147+
{"form_field": "value"},
148+
[
149+
("correct_file", ("test1.txt", b"content1", "text/plain")),
150+
("wrong_file[]", ("test2.txt", b"content2", "text/plain")),
151+
("form_array[]", "form_value"), # Regular form field, should keep []
152+
(
153+
"json_field[]",
154+
("", '{"key": "value"}', "application/json"),
155+
), # JSON field, might need []
156+
],
157+
)
158+
mock_forms_module.serialize_multipart_form = original_function
159+
160+
# Apply the patch
161+
self.hook._patch_multipart_serialization()
162+
163+
# Get the patched function
164+
patched_function = mock_forms_module.serialize_multipart_form
165+
166+
# Call the patched function
167+
media_type, form_data, files_list = patched_function(
168+
"multipart/form-data", Mock()
169+
)
170+
171+
# Verify the results - only actual file fields should have [] removed
172+
expected_files = [
173+
("correct_file", ("test1.txt", b"content1", "text/plain")),
174+
("wrong_file", ("test2.txt", b"content2", "text/plain")), # Fixed
175+
("form_array[]", "form_value"), # Preserved - not a file field
176+
(
177+
"json_field[]",
178+
("", '{"key": "value"}', "application/json"),
179+
), # Preserved - JSON content
180+
]
181+
assert files_list == expected_files
182+
183+
def test_file_field_detection_edge_cases(self):
184+
"""Test edge cases for file field detection."""
185+
# Empty content
186+
assert not self.hook._is_file_field_data(("test.txt", ""))
187+
188+
# None content
189+
assert not self.hook._is_file_field_data(("test.txt", None))
190+
191+
# List/tuple content (should be considered file-like)
192+
assert self.hook._is_file_field_data(("test.txt", [1, 2, 3]))
193+
assert self.hook._is_file_field_data(("test.txt", (1, 2, 3)))
194+
195+
# String content (should be considered file content)
196+
assert self.hook._is_file_field_data(("test.txt", "string content"))
197+
198+
# But not if it's the first element
199+
assert not self.hook._is_file_field_data(("string content",))
200+
201+
202+
if __name__ == "__main__":
203+
pytest.main([__file__])

0 commit comments

Comments
 (0)