Skip to content

Commit f96628c

Browse files
committed
make golang interface pass tests
1 parent 3fc02e5 commit f96628c

File tree

4 files changed

+84
-27
lines changed

4 files changed

+84
-27
lines changed

src/lpg/interfaces/file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@
66

77
from .lang.base import BaseLanguageInterface
88
from .lang.c import CLanguageInterface
9+
from .lang.golang import GoLanguageInterface
910
from .lang.java import JavaLanguageInterface
1011
from .lang.javascript import JavaScriptLanguageInterface
1112
from .lang.python import PythonLanguageInterface
1213
from .lang.python3 import Python3LanguageInterface
1314
from .lang.typescript import TypeScriptLanguageInterface
14-
from .lang.golang import GoLanguageInterface
1515

1616
LANGUAGE_INTERFACES: dict[str, BaseLanguageInterface] = {
1717
"c": CLanguageInterface(),
18-
"g": GoLanguageInterface(),
18+
"golang": GoLanguageInterface(),
1919
"java": JavaLanguageInterface(),
2020
"javascript": JavaScriptLanguageInterface(),
2121
"typescript": TypeScriptLanguageInterface(),

src/lpg/interfaces/lang/base.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,28 @@ def get_supplemental_code(self, template: str) -> str | None:
7474
if match is not None
7575
)
7676

77+
def is_void_return_type(self) -> bool:
78+
"""Determines if the solution function return type is void. Can be overridden."""
79+
return self.groups["returnType"] == "void"
80+
7781
def get_formatted_nonvoid_template(
78-
self, template: str, nonvoid_callback: Callable[[], str]
82+
self,
83+
template: str,
84+
nonvoid_callback: Callable[[], str],
85+
result_var_declaration: str | None = None,
7986
) -> str:
8087
"""Adjusts the return type and method call when the return type is void.
8188
Useful for C-style languages where assigning to a void variable is not allowed.
8289
"""
83-
if self.groups["returnType"] == "void":
90+
if self.is_void_return_type():
8491
self.groups["result_var_declaration"] = ""
8592
self.groups["result_var"] = "0"
8693
return template
87-
self.groups["result_var_declaration"] = f"{self.groups['returnType']} result = "
94+
self.groups["result_var_declaration"] = (
95+
f"{self.groups['returnType']} result = "
96+
if result_var_declaration is None
97+
else result_var_declaration
98+
)
8899
self.groups["result_var"] = "result"
89100
return nonvoid_callback()
90101

src/lpg/interfaces/lang/golang.py

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Project generator for the Go language."""
22

33
import re
4+
45
from .base import BaseLanguageInterface
56

67
# Go function signature pattern
@@ -9,44 +10,68 @@
910
flags=re.MULTILINE,
1011
)
1112

12-
SOLUTION_FILE_TEMPLATE = """\
13-
package solution
14-
15-
{supplemental_code}
16-
func {name}({params}) {returnType} {
17-
// TODO: Implement solution
18-
{return_statement}
19-
}
20-
"""
2113

2214
TEST_FILE_TEMPLATE = """\
2315
package main
2416
2517
import (
2618
"fmt"
27-
"./solution"
2819
)
2920
30-
func main() {
21+
func main() {{
3122
// Test case setup
3223
{params_setup}
3324
3425
// Execute solution
35-
{result_var_declaration}solution.{name}({params_call})
26+
{result_var_declaration}{name}({params_call})
3627
3728
// Display result
3829
fmt.Printf("{OUTPUT_RESULT_PREFIX} %v\\n", {result_var})
39-
}
30+
}}
4031
"""
32+
SOLUTION_REPLACEMENT_PATTERN = re.compile(r"\n}")
33+
SOLUTION_REPLACEMENT_TEMPLATE = "{return_statement}\n}}"
4134

4235

4336
class GoLanguageInterface(BaseLanguageInterface):
4437
"""Implementation of the Go language project template interface."""
4538

4639
function_signature_pattern = FUNCTION_SIGNATURE_PATTERN
47-
compile_command = ["go", "build", "-o", "test", "test.go"]
4840
test_command = ["./test"]
49-
default_output = "0"
41+
42+
@property
43+
def compile_command(self):
44+
args = ["go", "build", "-o", "test", "test.go", "solution.go"]
45+
46+
supplemental_filename = self.get_supplemental_filename()
47+
if supplemental_filename is not None:
48+
args.append(supplemental_filename)
49+
50+
return args
51+
52+
def get_supplemental_filename(self):
53+
"""Obtains the name of the supplemental file."""
54+
return "extra.go" if self.groups["supplemental_code"] else None
55+
56+
@property
57+
def default_output(self):
58+
59+
if self.is_void_return_type():
60+
return "0"
61+
if "[]" in self.groups["returnType"]:
62+
return "[]"
63+
output = self._get_default_value(self.groups["returnType"])
64+
# Converts the Go type to its string representation
65+
match output:
66+
case '""':
67+
return ""
68+
case "nil":
69+
return "<nil>"
70+
case _:
71+
return output
72+
73+
def is_void_return_type(self):
74+
return self.groups["returnType"].strip() == ""
5075

5176
def prepare_project_files(self, template: str):
5277
params = self.groups["params"].split(", ")
@@ -68,7 +93,7 @@ def prepare_project_files(self, template: str):
6893

6994
self.groups["params_setup"] = ";\n ".join(
7095
[
71-
f"{name} := {self._get_default_value(param_type)}"
96+
self._get_variable_declaration(name, param_type)
7297
for name, param_type in zip(param_names, param_types)
7398
]
7499
)
@@ -80,13 +105,31 @@ def prepare_project_files(self, template: str):
80105

81106
# Handle non-void return types
82107
formatted_template = self.get_formatted_nonvoid_template(
83-
TEST_FILE_TEMPLATE, lambda: TEST_FILE_TEMPLATE
108+
template,
109+
lambda: re.sub(
110+
SOLUTION_REPLACEMENT_PATTERN,
111+
(SOLUTION_REPLACEMENT_TEMPLATE.format(**self.groups)),
112+
template,
113+
),
114+
"result := ",
84115
)
85116

86-
return {
87-
"solution/solution.go": SOLUTION_FILE_TEMPLATE.format(**self.groups),
88-
"test.go": formatted_template.format(**self.groups),
117+
project_files = {
118+
"solution.go": f"package main\n\n{formatted_template}",
119+
"test.go": TEST_FILE_TEMPLATE.format(**self.groups),
89120
}
121+
if self.groups["supplemental_code"]:
122+
filename = self.get_supplemental_filename()
123+
project_files[filename] = (
124+
f"package main\n\n{self.groups["supplemental_code"]}"
125+
)
126+
return project_files
127+
128+
def _get_variable_declaration(self, name: str, variable_type: str) -> str:
129+
default_value = self._get_default_value(variable_type)
130+
if default_value == "nil":
131+
return f"var {name} {variable_type}"
132+
return f"{name} := {default_value}"
90133

91134
def _get_default_value(self, param_type: str) -> str:
92135
"""Returns a default value for the given Go type."""
@@ -100,8 +143,11 @@ def _get_default_value(self, param_type: str) -> str:
100143
"bool": "false",
101144
}
102145

146+
if "[]" in param_type:
147+
return f"{param_type}{{}}" # Default slice initialization
148+
103149
# Check for specific patterns
104-
if "[]" in param_type or "map" in param_type or "*" in param_type:
150+
if "map" in param_type or "*" in param_type:
105151
return "nil"
106152

107153
# Check for simple matches in dictionary

src/lpg/interfaces/lang_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def conduct_tests(self, code_snippets: list[dict[str, str]]):
7272
continue
7373
self.assertMultiLineEqual(
7474
result.stdout.decode().strip(),
75-
f"{OUTPUT_RESULT_PREFIX} {interface.default_output}",
75+
f"{OUTPUT_RESULT_PREFIX} {interface.default_output}".strip(),
7676
)
7777

7878

0 commit comments

Comments
 (0)