From 9ddd7c9f95f1f0c75e36b1fbe9d2dec3a9fbba22 Mon Sep 17 00:00:00 2001
From: michaelfyc
Date: Sat, 10 Dec 2022 12:54:31 +0800
Subject: [PATCH 1/4] feat(core): define WindowFunctionExpression and its visit pattern

---
 .../ktorm/expression/SqlExpressionVisitor.kt  |  51 +++
 .../org/ktorm/expression/SqlExpressions.kt    | 223 ++++++++++++++++++
 .../org/ktorm/expression/SqlFormatter.kt      |  43 +++
 3 files changed, 317 insertions(+)

diff --git a/ktorm-core/src/main/kotlin/org/ktorm/expression/SqlExpressionVisitor.kt b/ktorm-core/src/main/kotlin/org/ktorm/expression/SqlExpressionVisitor.kt
index 007a620c..9bfd09f3 100644
--- a/ktorm-core/src/main/kotlin/org/ktorm/expression/SqlExpressionVisitor.kt
+++ b/ktorm-core/src/main/kotlin/org/ktorm/expression/SqlExpressionVisitor.kt
@@ -68,6 +68,7 @@ public open class SqlExpressionVisitor {
             is BetweenExpression<*> -> visitBetween(expr)
             is ArgumentExpression -> visitArgument(expr)
             is FunctionExpression -> visitFunction(expr)
+            is WindowFunctionExpression -> visitWindowFunction(expr)
             is CaseWhenExpression -> visitCaseWhen(expr)
             else -> visitUnknown(expr)
         }
@@ -304,6 +305,56 @@ public open class SqlExpressionVisitor {
         }
     }
 
+    /**
+     * Visit a window specification, including its partition, order by and frame expressions.
+     */
+    protected open fun visitWindow(expr: WindowExpression): WindowExpression {
+        val partitionArguments = visitExpressionList(expr.partitionArguments)
+        val orderByExpressions = visitOrderByList(expr.orderByExpressions)
+        val frameExpression = expr.frameExpression?.let { frame ->
+            val start = visitFrame(frame.first)
+            val end = frame.second?.let { visitFrame(it) }
+            if (start === frame.first && end === frame.second) frame else Pair(start, end)
+        }
+
+        if (partitionArguments === expr.partitionArguments &&
+            orderByExpressions === expr.orderByExpressions &&
+            frameExpression === expr.frameExpression
+        ) {
+            return expr
+        }
+        return expr.copy(
+            partitionArguments = partitionArguments,
+            orderByExpressions = orderByExpressions,
+            frameExpression = frameExpression
+        )
+    }
+
+    /**
+     * Visit a boundary of a window frame clause, including its offset argument (if any).
+     */
+    protected open fun <T : Any> visitFrame(expr: FrameExpression<T>): FrameExpression<T> {
+        val argument = expr.argument?.let { visitScalar(it) }
+        if (argument === expr.argument) {
+            return expr
+        }
+        return expr.copy(argument = argument)
+    }
+
+    /**
+     * Visit a window function, including its arguments and its window specification (if any).
+     *
+     * A missing window is not rejected here, SQL formatters decide whether the expression can be rendered.
+     */
+    protected open fun <T : Any> visitWindowFunction(expr: WindowFunctionExpression<T>): WindowFunctionExpression<T> {
+        val arguments = visitExpressionList(expr.arguments)
+        val window = expr.window?.let { visitWindow(it) }
+        if (arguments === expr.arguments && window === expr.window) {
+            return expr
+        }
+        return expr.copy(arguments = arguments, window = window)
+    }
+
     protected open fun <T : Any> visitCaseWhen(expr: CaseWhenExpression<T>): CaseWhenExpression<T> {
         val operand = expr.operand?.let { visitScalar(it) }
         val whenClauses = visitWhenClauses(expr.whenClauses)
diff --git a/ktorm-core/src/main/kotlin/org/ktorm/expression/SqlExpressions.kt b/ktorm-core/src/main/kotlin/org/ktorm/expression/SqlExpressions.kt
index 700378ed..f297aa5d 100644
--- a/ktorm-core/src/main/kotlin/org/ktorm/expression/SqlExpressions.kt
+++ b/ktorm-core/src/main/kotlin/org/ktorm/expression/SqlExpressions.kt
@@ -569,6 +569,229 @@ public data class FunctionExpression<T : Any>(
     override val extraProperties: Map<String, Any> = emptyMap()
 ) : ScalarExpression<T>()
 
+/**
+ * The enum of window function types, used as the function name of a [WindowFunctionExpression].
+ */
+public enum class WindowFunctionType(private val type: String) {
+    // aggregate
+    /**
+     * The min function, translated to `min(column)` in SQL.
+     */
+    MIN("min"),
+
+    /**
+     * The max function, translated to `max(column)` in SQL.
+     */
+    MAX("max"),
+
+    /**
+     * The avg function, translated to `avg(column)` in SQL.
+     */
+    AVG("avg"),
+
+    /**
+     * The sum function, translated to `sum(column)` in SQL.
+     */
+    SUM("sum"),
+
+    /**
+     * The count function, translated to `count(column)` in SQL.
+     */
+    COUNT("count"),
+
+    // non-aggregate
+
+    /**
+     * The cume_dist function, translated to `cume_dist()` in SQL.
+     */
+    CUME_DIST("cume_dist"),
+
+    /**
+     * The dense_rank function, translated to `dense_rank()` in SQL.
+     */
+    DENSE_RANK("dense_rank"),
+
+    /**
+     * The first_value function, translated to `first_value(column)` in SQL.
+     */
+    FIRST_VALUE("first_value"),
+
+    /**
+     * The lag function, translated to `lag(column, offset, default_value)` in SQL.
+     */
+    LAG("lag"),
+
+    /**
+     * The last_value function, translated to `last_value(column)` in SQL.
+     */
+    LAST_VALUE("last_value"),
+
+    /**
+     * The lead function, translated to `lead(column, offset, default_value)` in SQL.
+     */
+    LEAD("lead"),
+
+    /**
+     * The nth_value function, translated to `nth_value(column, n)` in SQL.
+     */
+    NTH_VALUE("nth_value"),
+
+    /**
+     * The ntile function, translated to `ntile(n)` in SQL.
+     */
+    NTILE("ntile"),
+
+    /**
+     * The percent_rank function, translated to `percent_rank()` in SQL.
+     */
+    PERCENT_RANK("percent_rank"),
+
+    /**
+     * The rank function, translated to `rank()` in SQL.
+     */
+    RANK("rank"),
+
+    /**
+     * The row_number function, translated to `row_number()` in SQL.
+     */
+    ROW_NUMBER("row_number");
+
+    override fun toString(): String {
+        return type
+    }
+}
+
+/**
+ * Window function expression, represents a SQL window function call.
+ *
+ * @property functionName the name of the window function.
+ * @property arguments the arguments passed to the window function.
+ * @property window window specification of the window function, `null` if it's not specified yet.
+ *
+ * @since 3.6
+ */
+public data class WindowFunctionExpression<T : Any>(
+    val functionName: WindowFunctionType,
+    val arguments: List<ScalarExpression<*>>,
+    val window: WindowExpression?,
+    override val sqlType: SqlType<T>,
+    override val isLeafNode: Boolean = false,
+    override val extraProperties: Map<String, Any> = emptyMap()
+) : ScalarExpression<T>()
+
+/**
+ * Window expression, represents the window specification of a window function, i.e. the content of `over (...)`.
+ *
+ * @property partitionArguments column expressions passed to the `partition by` clause.
+ * @property orderByExpressions order-by expressions passed to the `order by` clause.
+ * @property frameUnit frame unit of the window frame clause, `null` if there is no frame clause.
+ * @property frameExpression frame boundaries of the window frame clause; the second boundary is `null`
+ * when the frame clause has a start boundary only.
+ *
+ * @since 3.6
+ */
+public data class WindowExpression(
+    val partitionArguments: List<ScalarExpression<*>>,
+    val orderByExpressions: List<OrderByExpression>,
+    val frameUnit: FrameUnitType?,
+    val frameExpression: Pair<FrameExpression<*>, FrameExpression<*>?>?,
+    override val isLeafNode: Boolean = false,
+    override val extraProperties: Map<String, Any> = emptyMap()
+) : SqlExpression()
+
+/**
+ * The enum type of frame unit in [WindowExpression].
+ *
+ * @since 3.6
+ */
+public enum class FrameUnitType(private val type: String) {
+    /**
+     * The `rows` frame unit, followed by a single frame boundary.
+     */
+    ROWS("rows"),
+
+    /**
+     * The `range` frame unit, followed by a single frame boundary.
+     */
+    RANGE("range"),
+
+    /**
+     * The `groups` frame unit, followed by a single frame boundary.
+     */
+    GROUPS("groups"),
+
+    /**
+     * The `rows between` frame unit, followed by a start and an end boundary.
+     */
+    ROWS_BETWEEN("rows between"),
+
+    /**
+     * The `range between` frame unit, followed by a start and an end boundary.
+     */
+    RANGE_BETWEEN("range between"),
+
+    /**
+     * The `groups between` frame unit, followed by a start and an end boundary.
+     */
+    GROUPS_BETWEEN("groups between");
+
+    override fun toString(): String {
+        return type
+    }
+}
+
+/**
+ * The enum type of frame extent type in [FrameExpression].
+ *
+ * @since 3.6
+ */
+public enum class FrameExtentType(private val type: String) {
+    /**
+     * The `current row` frame boundary.
+     */
+    CURRENT_ROW("current row"),
+
+    /**
+     * The `unbounded preceding` frame boundary.
+     */
+    UNBOUNDED_PRECEDING("unbounded preceding"),
+
+    /**
+     * The `unbounded following` frame boundary.
+     */
+    UNBOUNDED_FOLLOWING("unbounded following"),
+
+    /**
+     * The `expr preceding` frame boundary, the offset `expr` is given by [FrameExpression.argument].
+     */
+    PRECEDING("preceding"),
+
+    /**
+     * The `expr following` frame boundary, the offset `expr` is given by [FrameExpression.argument].
+     */
+    FOLLOWING("following");
+
+    override fun toString(): String {
+        return type
+    }
+}
+
+/**
+ * Frame expression, represents a boundary of a SQL window frame clause.
+ *
+ * @property frameExtentType frame extent type of the boundary.
+ * @property argument the offset argument of the boundary (e.g. `n` in `n preceding`), `null` if there is none.
+ *
+ * @since 3.6
+ */
+public data class FrameExpression<T : Any>(
+    val frameExtentType: FrameExtentType,
+    val argument: ScalarExpression<*>?,
+    override val sqlType: SqlType<T>,
+    override val isLeafNode: Boolean = false,
+    override val extraProperties: Map<String, Any> = emptyMap()
+) : ScalarExpression<T>()
+
 /**
  * Case-when expression, represents a SQL case-when clause.
* diff --git a/ktorm-core/src/main/kotlin/org/ktorm/expression/SqlFormatter.kt b/ktorm-core/src/main/kotlin/org/ktorm/expression/SqlFormatter.kt index 7f7eff27..fa5d93c7 100644 --- a/ktorm-core/src/main/kotlin/org/ktorm/expression/SqlFormatter.kt +++ b/ktorm-core/src/main/kotlin/org/ktorm/expression/SqlFormatter.kt @@ -537,6 +537,49 @@ public abstract class SqlFormatter( return expr } + override fun visitWindow(expr: WindowExpression): WindowExpression{ + if (expr.partitionArguments.isNotEmpty()) { + writeKeyword("partition by ") + visitExpressionList(expr.partitionArguments) + } + + if (expr.orderByExpressions.isNotEmpty()) { + writeKeyword("order by ") + visitOrderByList(expr.orderByExpressions) + } + if (expr.frameUnit != null) { + writeKeyword("${expr.frameUnit} ") + if (expr.frameExpression != null) { + val first = expr.frameExpression.first + val second = expr.frameExpression.second + first.argument?.let { visit(it) } + writeKeyword("${first.frameExtentType} ") + if(second!=null){ + writeKeyword("and ") + second.argument?.let { visit(it) } + writeKeyword("${second.frameExtentType}") + } + } + } + removeLastBlank() + return expr + } + + override fun visitWindowFunction(expr: WindowFunctionExpression): WindowFunctionExpression { + writeKeyword("${expr.functionName}(") + visitExpressionList(expr.arguments) + removeLastBlank() + writeKeyword(") over (") + check(expr.window != null) { + throw DialectFeatureNotSupportedException("no anonymous or named windows found in window function expression `${expr.functionName}`.") + } + + visitWindow(expr.window) + + write(")") + return expr + } + override fun visitCaseWhen(expr: CaseWhenExpression): CaseWhenExpression { writeKeyword("case ") From f3cdfee3c3f448f83ffe2ee502eaa1cf8ea7bbf1 Mon Sep 17 00:00:00 2001 From: michaelfyc Date: Sat, 10 Dec 2022 12:56:10 +0800 Subject: [PATCH 2/4] feat(mysql): implement common window function expressions for MySQL dialect --- .../ktorm/support/mysql/WindowFunctions.kt | 374 
++++++++++++++++++ 1 file changed, 374 insertions(+) create mode 100644 ktorm-support-mysql/src/main/kotlin/org/ktorm/support/mysql/WindowFunctions.kt diff --git a/ktorm-support-mysql/src/main/kotlin/org/ktorm/support/mysql/WindowFunctions.kt b/ktorm-support-mysql/src/main/kotlin/org/ktorm/support/mysql/WindowFunctions.kt new file mode 100644 index 00000000..43d6e5da --- /dev/null +++ b/ktorm-support-mysql/src/main/kotlin/org/ktorm/support/mysql/WindowFunctions.kt @@ -0,0 +1,374 @@ +/* + * Copyright 2018-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.ktorm.support.mysql + +import org.ktorm.expression.AggregateExpression +import org.ktorm.expression.ArgumentExpression +import org.ktorm.expression.FrameExpression +import org.ktorm.expression.FrameExtentType +import org.ktorm.expression.FrameUnitType +import org.ktorm.expression.OrderByExpression +import org.ktorm.expression.WindowExpression +import org.ktorm.expression.WindowFunctionExpression +import org.ktorm.expression.WindowFunctionType +import org.ktorm.schema.ColumnDeclaring +import org.ktorm.schema.SqlType + +/** + * since 3.6.0 + * Window functions are available since MySQL 5.7 + * reference: https://dev.mysql.com/doc/refman/8.0/en/window-functions.html + * + */ + + +/** + * Most MySQL aggregate functions also can be used as window functions. 
+ */ +public infix fun AggregateExpression.over(window: WindowExpression): WindowFunctionExpression { + val arguments = if (this.argument != null) { + listOf(this.argument!!) + } else { + emptyList() + } + return WindowFunctionExpression( + WindowFunctionType.valueOf(this.type.name), + arguments, + window, + this.sqlType + ) +} + +public infix fun WindowFunctionExpression.over(window: WindowExpression): WindowFunctionExpression { + return WindowFunctionExpression( + functionName, + arguments, + window, + this.sqlType + ) +} + +public fun partitionBy(vararg columns: ColumnDeclaring<*>): WindowExpression { + return WindowExpression( + partitionArguments = columns.map { it.asExpression() }, + orderByExpressions = emptyList(), + frameUnit = null, + frameExpression = null, + ) +} + +public fun orderBy(vararg orderByExpression: OrderByExpression): WindowExpression { + return WindowExpression( + partitionArguments = emptyList(), + orderByExpressions = orderByExpression.asList(), + frameUnit = null, + frameExpression = null, + ) +} + + +public fun WindowExpression.orderBy(vararg orderByExpression: OrderByExpression): WindowExpression { + return WindowExpression( + partitionArguments, + orderByExpression.asList(), + null, + null + ) +} + +public fun Int.preceding(): FrameExpression { + return FrameExpression( + frameExtentType = FrameExtentType.PRECEDING, + argument = ArgumentExpression( + value = this, + sqlType = SqlType.of()!! + ), + sqlType = SqlType.of()!! + ) +} + +public fun Int.following(): FrameExpression { + return FrameExpression( + frameExtentType = FrameExtentType.FOLLOWING, + argument = ArgumentExpression( + value = this, + sqlType = SqlType.of()!! + ), + sqlType = SqlType.of()!! + ) +} + +/** + * Translated to MySQL reserved keyword `UNBOUNDED PRECEDING`. + */ +public val UNBOUNDED_PRECEDING: FrameExpression = FrameExpression( + frameExtentType = FrameExtentType.UNBOUNDED_PRECEDING, + argument = null, + sqlType = SqlType.of()!! 
+) + +/** + * Translated to MySQL reserved keyword `UNBOUNDED FOLLOWING`. + */ +public val UNBOUNDED_FOLLOWING: FrameExpression = FrameExpression( + frameExtentType = FrameExtentType.UNBOUNDED_FOLLOWING, + argument = null, + sqlType = SqlType.of()!! +) + +/** + * Translated to MySQL reserved key word `CURRENT ROW`. + */ +public val CURRENT_ROW: FrameExpression = FrameExpression( + frameExtentType = FrameExtentType.CURRENT_ROW, + argument = null, + sqlType = SqlType.of()!! +) + +/** + * Translated to MySQL frame unit `rows between`. + */ +public fun WindowExpression.rowsBetween( + left: FrameExpression, + right: FrameExpression +): WindowExpression { + return WindowExpression( + partitionArguments, + orderByExpressions, + FrameUnitType.ROWS_BETWEEN, + Pair(left, right) + ) +} + +/** + * Translated to MySQL frame unit `range between`. + */ +public fun WindowExpression.rangeBetween( + left: FrameExpression, + right: FrameExpression +): WindowExpression { + return WindowExpression( + partitionArguments, + orderByExpressions, + FrameUnitType.RANGE_BETWEEN, + Pair(left, right) + ) +} + +/** + * Translated to MySQL frame unit `range`. + */ +public fun WindowExpression.range( + frameExpression: FrameExpression, +): WindowExpression { + return WindowExpression( + partitionArguments, + orderByExpressions, + FrameUnitType.RANGE, + Pair(frameExpression, null) + ) +} + +/** + * Translated to MySQL frame unit `rows`. + */ +public fun WindowExpression.row( + frameExpression: FrameExpression, +): WindowExpression { + return WindowExpression( + partitionArguments, + orderByExpressions, + FrameUnitType.ROWS, + Pair(frameExpression, null) + ) +} + +/** + * MySQL rank window function, translated to `rank()`. + */ +public fun rank(): WindowFunctionExpression { + return WindowFunctionExpression( + functionName = WindowFunctionType.RANK, + arguments = emptyList(), + window = null, + sqlType = SqlType.of()!! + ) +} + +/** + * MySQL row_number window function, translated to `row_number()`. 
+ */ +public fun rowNumber(): WindowFunctionExpression { + return WindowFunctionExpression( + functionName = WindowFunctionType.ROW_NUMBER, + arguments = emptyList(), + window = null, + sqlType = SqlType.of()!! + ) +} + +/** + * MySQL dense_rank window function, translated to `dense_rank()`. + */ +public fun denseRank(): WindowFunctionExpression { + return WindowFunctionExpression( + functionName = WindowFunctionType.DENSE_RANK, + arguments = emptyList(), + window = null, + sqlType = SqlType.of()!! + ) +} + + +/** + * MySQL percent_rank window function, translated to `percent_rank()`. + */ +public fun percentRank(): WindowFunctionExpression { + return WindowFunctionExpression( + functionName = WindowFunctionType.PERCENT_RANK, + arguments = emptyList(), + window = null, + sqlType = SqlType.of()!! + ) +} + +/** + * MySQL cume_dist window function, translated to `cume_dist()`. + */ +public fun cumeDist(): WindowFunctionExpression { + return WindowFunctionExpression( + functionName = WindowFunctionType.CUME_DIST, + arguments = emptyList(), + window = null, + sqlType = SqlType.of()!! + ) +} + +/** + * MySQL first_value window function, translated to `first_value(column)`. + */ +public fun firstValue(column: ColumnDeclaring): WindowFunctionExpression { + return WindowFunctionExpression( + functionName = WindowFunctionType.FIRST_VALUE, + arguments = listOf(column.asExpression()), + window = null, + sqlType = column.sqlType + ) +} + +/** + * MySQL last_value window function, translated to `last_value(column)`. + */ +public fun lastValue(column: ColumnDeclaring): WindowFunctionExpression { + return WindowFunctionExpression( + functionName = WindowFunctionType.LAST_VALUE, + arguments = listOf(column.asExpression()), + window = null, + sqlType = column.sqlType + ) +} + +/** + * MySQL ntile window function, translated to `ntile(n)`. 
+ */ +public fun ntile(n: Int): WindowFunctionExpression { + return WindowFunctionExpression( + functionName = WindowFunctionType.NTILE, + arguments = listOf( + ArgumentExpression( + n, + SqlType.of() ?: error("Cannot detect the param's SqlType, please specify manually.") + ) + ), + window = null, + sqlType = SqlType.of()!! + ) +} + +/** + * MySQL nth_value window function, translated to `nth_value(column, n)`. + */ +public fun nthValue(column: ColumnDeclaring, n: Int): WindowFunctionExpression { + return WindowFunctionExpression( + functionName = WindowFunctionType.NTH_VALUE, + arguments = listOf( + column.asExpression(), + ArgumentExpression(n, SqlType.of() ?: error("Cannot detect the param's SqlType, please specify manually.")) + ), + window = null, + sqlType = column.sqlType + ) +} + +/** + * MySQL lead window function, translated to `lead(column, offset, )`. + */ +public fun lead( + column: ColumnDeclaring, + offset: Int, + defaultValue: Int? = null +): WindowFunctionExpression { + val arguments = mutableListOf( + column.asExpression(), + ArgumentExpression(offset, SqlType.of() ?: error("Cannot detect the param's SqlType, please specify manually.")) + ) + if (defaultValue != null) { + arguments.add( + ArgumentExpression( + defaultValue, + SqlType.of() ?: error("Cannot detect the param's SqlType, please specify manually.") + ) + ) + } + + return WindowFunctionExpression( + functionName = WindowFunctionType.LEAD, + arguments = arguments, + window = null, + sqlType = column.sqlType + ) +} + +/** + * MySQL lag window function, translated to `lag(column, offset, )`. + */ +public fun lag( + column: ColumnDeclaring, + offset: Int, + defaultValue: Int? 
= null +): WindowFunctionExpression { + val arguments = mutableListOf( + column.asExpression(), + ArgumentExpression(offset, SqlType.of() ?: error("Cannot detect the param's SqlType, please specify manually.")) + ) + if (defaultValue != null) { + arguments.add( + ArgumentExpression( + defaultValue, + SqlType.of() ?: error("Cannot detect the param's SqlType, please specify manually.") + ) + ) + } + + return WindowFunctionExpression( + functionName = WindowFunctionType.LAG, + arguments = arguments, + window = null, + sqlType = column.sqlType + ) +} + From f69d8e63b8f2289471755e6475feff46cce11cc1 Mon Sep 17 00:00:00 2001 From: michaelfyc Date: Sat, 10 Dec 2022 12:57:00 +0800 Subject: [PATCH 3/4] test(mysql): add window functions unit test for MySQL dialect --- .../org/ktorm/support/mysql/CommonTest.kt | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/ktorm-support-mysql/src/test/kotlin/org/ktorm/support/mysql/CommonTest.kt b/ktorm-support-mysql/src/test/kotlin/org/ktorm/support/mysql/CommonTest.kt index 0e3f34ce..ea91fb50 100644 --- a/ktorm-support-mysql/src/test/kotlin/org/ktorm/support/mysql/CommonTest.kt +++ b/ktorm-support-mysql/src/test/kotlin/org/ktorm/support/mysql/CommonTest.kt @@ -3,6 +3,7 @@ package org.ktorm.support.mysql import org.hamcrest.CoreMatchers.equalTo import org.hamcrest.MatcherAssert.assertThat import org.junit.Test +import org.ktorm.database.DialectFeatureNotSupportedException import org.ktorm.database.use import org.ktorm.dsl.* import org.ktorm.entity.* @@ -14,6 +15,8 @@ import java.util.concurrent.ExecutionException import java.util.concurrent.Executors import java.util.concurrent.TimeUnit import java.util.concurrent.TimeoutException +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith /** * Created by vince on Dec 12, 2018. 
@@ -320,6 +323,47 @@ class CommonTest : BaseMySqlTest() {
 
         assert(name == "VINCE")
     }
+
+    @Test
+    fun testWindowFunctions_0() {
+        // for those that are aggregate functions
+        val sum = (sum(Employees.salary) over partitionBy(Employees.departmentId)).aliased("sum")
+        val departmentSum = database.from(Employees)
+            .selectDistinct(Employees.departmentId, sum)
+            .associate { Pair(it[Employees.departmentId]!!, it[sum]!!) }
+        assertEquals(mapOf(1 to 150L, 2 to 300L), departmentSum)
+
+        // for those that are non-aggregate functions
+        val rank = (rank() over partitionBy(Employees.departmentId).orderBy(Employees.salary.desc())).aliased("rank")
+        val employeeSalaryRanks = database.from(Employees)
+            .select(Employees.name, Employees.departmentId, rank)
+            .map {
+                Triple(it[Employees.name], it[Employees.departmentId], it[rank])
+            }
+        val topSalaryEmployees = employeeSalaryRanks.filter {
+            it.third == 1
+        }.map { it.first }.toSet()
+        assertEquals(setOf("vince", "tom"), topSalaryEmployees)
+
+        // for those non-aggregate functions that require parameters
+        val group = (ntile(2) over orderBy(Employees.departmentId.asc())).aliased("group_num")
+        val employeeGroup = database.from(Employees).select(Employees.id, group).associate {
+            Pair(it[Employees.id]!!, it[group]!!)
+        }
+        assertEquals(mapOf(1 to 1, 2 to 1, 3 to 2, 4 to 2), employeeGroup)
+    }
+
+    @Test
+    fun testWindowFunction_1() {
+        // An exception should be thrown when no window is specified for a window function
+        assertFailsWith<DialectFeatureNotSupportedException> {
+            val rank = rank().aliased("rank")
+            // iterate over the results to force the query to be built and run
+            database.from(Employees)
+                .select(Employees.name, Employees.departmentId, rank)
+                .map { it[rank] }
+        }
+    }
 
     @Test
     fun testIf() {
From a073b7e13c99c883da2efc10eba33d543ef5ab52 Mon Sep 17 00:00:00 2001
From: michaelfyc
Date: Sat, 10 Dec 2022 13:27:12 +0800
Subject: [PATCH 4/4] chore: update developer info

---
 buildSrc/src/main/kotlin/ktorm.maven-publish.gradle.kts | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/buildSrc/src/main/kotlin/ktorm.maven-publish.gradle.kts b/buildSrc/src/main/kotlin/ktorm.maven-publish.gradle.kts
index b85e48a0..82492f18 100644
--- a/buildSrc/src/main/kotlin/ktorm.maven-publish.gradle.kts
+++ b/buildSrc/src/main/kotlin/ktorm.maven-publish.gradle.kts
@@ -146,6 +146,11 @@ publishing {
                     name.set("夜里的向日葵")
                     email.set("641571835@qq.com")
                 }
+                developer {
+                    id.set("michaelfyc")
+                    name.set("michaelfyc")
+                    email.set("michael.fyc@outlook.com")
+                }
             }
         }
     }