How to get conditional running cumulative sum based on current row and previous rows?
How do I perform a running cumulative sum that is based on a condition involving the current row and previous rows?
Given the following table:
acc | value | threshold
---|---|---
3 | 1 | 1
1 | 2 | 2
2 | 3 | 2
I would like to find the cumulative sum of `acc` over all rows from the start up to the current row, counting only the rows whose `value >= threshold`, where `threshold` is taken from the current row. The expected output is `3, 1, 3`.
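Stated as a formula (numbering rows from 1 in table order), the desired output for row $i$ is:

$$\text{output}_i = \sum_{\substack{j \le i \\ \text{value}_j \ge \text{threshold}_i}} \text{acc}_j$$

For the table above: $\text{output}_1 = 3$ (row 1, since $1 \ge 1$), $\text{output}_2 = 1$ (only row 2, since $2 \ge 2$ but $1 < 2$), and $\text{output}_3 = 1 + 2 = 3$ (rows 2 and 3, since $2 \ge 2$ and $3 \ge 2$).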
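That is, the equivalent Python code might look like this:
import pandas as pd
df = pd.DataFrame({"acc": [3, 1, 2], "value": [1, 2, 3], "threshold": [1, 2, 2]})
df["cumsum"] = 0
for i in range(len(df)):
    for j in range(i + 1):  # from the start up to and including row i
        if df.loc[j, "value"] >= df.loc[i, "threshold"]:  # current row's threshold
            df.loc[i, "cumsum"] += df.loc[j, "acc"]  # add row j's acc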
I tried using a windowed sum:
import pyspark.sql.functions as F
from pyspark.sql.window import Window
df = spark.createDataFrame([(3, 1, 1), (1, 2, 2), (2, 3, 2)], ["acc", "value", "threshold"])
window = Window.rowsBetween(Window.unboundedPreceding, Window.currentRow)
display(df.withColumn("output", F.sum(F.when(F.col("value") >= F.col("threshold"), F.col("acc"))).over(window)))
But this gave `3, 4, 6`, because each row's `value` was compared against that same row's `threshold`, rather than against the current row's `threshold`.
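To see why, here is a plain-Python sketch (not Spark) of what the attempted query computes. The condition is evaluated once per row, independently of which row the window ends at, so the window only ever sums a fixed list of contributions. Using `0` for non-matching rows is an assumption of the sketch; in Spark, `F.when` without `otherwise` yields `null`, which `sum` skips, so the totals are the same.

```python
from itertools import accumulate

# (acc, value, threshold) rows from the question's table
rows = [(3, 1, 1), (1, 2, 2), (2, 3, 2)]

# Each row's value is compared with that same row's threshold, once,
# before any window is applied.
contributions = [acc if value >= threshold else 0 for acc, value, threshold in rows]

# A running sum over those fixed contributions is what the windowed sum computes.
attempted = list(accumulate(contributions))
print(attempted)  # [3, 4, 6]
```

Getting `3, 1, 3` instead requires re-evaluating the condition for every earlier row each time the current row (and therefore the threshold) changes, which a plain running sum cannot express.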
1 answer
- Use a `collect_list` window to collect the `acc` and `value` of all rows from the start up to the current row into an array of structs
- Then `filter` that array based on each struct's `value` and the current row's `threshold`
- Use `aggregate` to calculate the result by adding up the `acc` field of each remaining struct
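The three steps above can be mirrored in plain Python as a rough analogue (this is not Spark code): a prefix slice plays the role of `collect_list` over `UNBOUNDED PRECEDING AND CURRENT ROW`, a list comprehension plays `filter`, and `functools.reduce` with a starting value of `0` plays `aggregate`. The sketch assumes the rows are already in the desired order, which is the job the `order` column does in the Spark query.

```python
from functools import reduce

# (acc, value, threshold) rows from the question, already in the desired order
rows = [(3, 1, 1), (1, 2, 2), (2, 3, 2)]

output = []
for i, (_, _, threshold) in enumerate(rows):
    # collect_list: every (acc, value) from the first row up to the current row
    collected = [(acc, value) for acc, value, _ in rows[: i + 1]]
    # filter: keep entries whose value meets the *current* row's threshold
    kept = [s for s in collected if s[1] >= threshold]
    # aggregate: fold the kept entries, starting from 0 and adding each acc
    output.append(reduce(lambda total, s: total + s[0], kept, 0))

print(output)  # [3, 1, 3]
```

Because this builds one prefix list per row, its cost grows quadratically with the number of rows. That is worth keeping in mind for the Spark query as well, since it also materialises one array per row.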
Note that doing so may reorder the output, so I added an `order` column, which is used both to order the window and to sort the result.
import pyspark.sql.functions as F
df = spark.createDataFrame([(1, 3, 1, 1), (2, 1, 2, 2), (3, 2, 3, 2)], ["order", "acc", "value", "threshold"])
display(
df
.withColumn("output", F.expr("""
aggregate(
filter(
collect_list(struct(acc, value)) OVER (ORDER BY order ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW),
s -> s.value >= threshold
),
0L,
(output, s) -> output + s.acc
)
"""))
.orderBy("order")
)
order | acc | value | threshold | output |
---|---|---|---|---|
1 | 3 | 1 | 1 | 3 |
2 | 1 | 2 | 2 | 1 |
3 | 2 | 3 | 2 | 3 |