Skip to content

Commit 43c5098

Browse files
authored
Rewrite version result to fix column name (#179)
* Rewrite version result to fix column name When `select version()` is run, return a column named `version` instead of `version()` so we are more compatible with postgresql behaviour. This is relied upon by some clients, including pgadmin4. * Fix clippy warning
1 parent 2ebd6ec commit 43c5098

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

datafusion-postgres/src/sql/parser.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use datafusion::sql::sqlparser::parser::ParserError;
77
use datafusion::sql::sqlparser::tokenizer::Token;
88
use datafusion::sql::sqlparser::tokenizer::TokenWithSpan;
99

10+
use crate::sql::rules::FixVersionColumnName;
11+
1012
use super::rules::AliasDuplicatedProjectionRewrite;
1113
use super::rules::CurrentUserVariableToSessionUserFunctionCall;
1214
use super::rules::FixArrayLiteral;
@@ -212,6 +214,7 @@ impl PostgresCompatibilityParser {
212214
Arc::new(CurrentUserVariableToSessionUserFunctionCall),
213215
Arc::new(FixCollate),
214216
Arc::new(RemoveSubqueryFromProjection),
217+
Arc::new(FixVersionColumnName),
215218
],
216219
}
217220
}

datafusion-postgres/src/sql/rules.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,50 @@ impl SqlStatementRewriteRule for RemoveSubqueryFromProjection {
770770
}
771771
}
772772

773+
/// `select version()` should return column named `version` not `version()`
774+
#[derive(Debug)]
775+
pub struct FixVersionColumnName;
776+
777+
struct FixVersionColumnNameVisitor;
778+
779+
impl VisitorMut for FixVersionColumnNameVisitor {
780+
type Break = ();
781+
782+
fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
783+
if let SetExpr::Select(select) = query.body.as_mut() {
784+
for projection in &mut select.projection {
785+
if let SelectItem::UnnamedExpr(Expr::Function(f)) = projection {
786+
if f.name.0.len() == 1 {
787+
if let ObjectNamePart::Identifier(part) = &f.name.0[0] {
788+
if part.value == "version" {
789+
if let FunctionArguments::List(args) = &f.args {
790+
if args.args.is_empty() {
791+
*projection = SelectItem::ExprWithAlias {
792+
expr: Expr::Function(f.clone()),
793+
alias: Ident::new("version"),
794+
}
795+
}
796+
}
797+
}
798+
}
799+
}
800+
}
801+
}
802+
}
803+
804+
ControlFlow::Continue(())
805+
}
806+
}
807+
808+
impl SqlStatementRewriteRule for FixVersionColumnName {
809+
fn rewrite(&self, mut s: Statement) -> Statement {
810+
let mut visitor = FixVersionColumnNameVisitor;
811+
let _ = s.visit(&mut visitor);
812+
813+
s
814+
}
815+
}
816+
773817
#[cfg(test)]
774818
mod tests {
775819
use super::*;
@@ -1021,4 +1065,16 @@ mod tests {
10211065
"SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid, true) FROM pg_catalog.pg_attrdef d WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef), a.attnotnull, (SELECT c.collname FROM pg_catalog.pg_collation c, pg_catalog.pg_type t WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation LIMIT 1) AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum;",
10221066
"SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), NULL, a.attnotnull, NULL AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute AS a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum");
10231067
}
1068+
1069+
#[test]
1070+
fn test_version_rewrite() {
1071+
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixVersionColumnName)];
1072+
1073+
assert_rewrite!(&rules, "SELECT version()", "SELECT version() AS version");
1074+
1075+
// Make sure we don't rewrite things we should leave alone
1076+
assert_rewrite!(&rules, "SELECT version() as foo", "SELECT version() AS foo");
1077+
assert_rewrite!(&rules, "SELECT version(foo)", "SELECT version(foo)");
1078+
assert_rewrite!(&rules, "SELECT foo.version()", "SELECT foo.version()");
1079+
}
10241080
}

0 commit comments

Comments
 (0)