Support for splitting nested branching operators within policies (#1136)
* Support for splitting nested branching operators within policies
* Introduce an ast.Heights() helper
* Updated tests and expanded flattening to all calls
* Added test case for comprehension pruning during unnest
diff --git a/common/ast/ast.go b/common/ast/ast.go
index b807669..62c09cf 100644
--- a/common/ast/ast.go
+++ b/common/ast/ast.go
@@ -160,6 +160,13 @@
return visitor.maxID + 1
}
+// Heights computes the height of every expression in the AST and returns a map from expression id
+// to height.
+//
+// Identifiers and literals have height zero; every other expression is one taller than its tallest
+// child. The traversal is post-order, so each child's height is recorded before its parent is visited.
+func Heights(a *AST) map[int64]int {
+	heights := heightVisitor{}
+	PostOrderVisit(a.Expr(), heights)
+	return heights
+}
+
// NewSourceInfo creates a simple SourceInfo object from an input common.Source value.
func NewSourceInfo(src common.Source) *SourceInfo {
var lineOffsets []int32
@@ -455,3 +462,74 @@
v.maxID = e.ID()
}
}
+
+// heightVisitor records, for each visited expression or entry id, the height of the subtree rooted
+// there. Child heights are read back from the map, so it must be driven by a post-order traversal.
+type heightVisitor map[int64]int
+
+// VisitExpr records the height of e as one more than the height of its tallest child.
+//
+// Identifiers, literals, and any other kind without child expressions have a height of zero.
+func (hv heightVisitor) VisitExpr(e Expr) {
+	id := e.ID()
+	// Seed with zero so kinds without children (IdentKind, LiteralKind, ...) need no further work.
+	hv[id] = 0
+	var tallestChild int
+	switch e.Kind() {
+	case SelectKind:
+		tallestChild = hv[e.AsSelect().Operand().ID()]
+	case CallKind:
+		call := e.AsCall()
+		tallestChild = hv.maxHeight(call.Args()...)
+		// The receiver of a member call (e.g. 'a' in 'a'.size()) is a child as well.
+		if call.IsMemberFunction() {
+			if targetHeight := hv[call.Target().ID()]; targetHeight > tallestChild {
+				tallestChild = targetHeight
+			}
+		}
+	case ListKind:
+		tallestChild = hv.maxHeight(e.AsList().Elements()...)
+	case MapKind:
+		tallestChild = hv.maxEntryHeight(e.AsMap().Entries()...)
+	case StructKind:
+		tallestChild = hv.maxEntryHeight(e.AsStruct().Fields()...)
+	case ComprehensionKind:
+		comp := e.AsComprehension()
+		tallestChild = hv.maxHeight(
+			comp.IterRange(),
+			comp.AccuInit(),
+			comp.LoopCondition(),
+			comp.LoopStep(),
+			comp.Result(),
+		)
+	default:
+		return
+	}
+	hv[id] = tallestChild + 1
+}
+
+// VisitEntryExpr records the height of a map entry or struct field.
+//
+// An entry does not add a level of its own: its height is that of its tallest component (the key or
+// value of a map entry, the value of a struct field).
+func (hv heightVisitor) VisitEntryExpr(e EntryExpr) {
+	id := e.ID()
+	hv[id] = 0
+	switch e.Kind() {
+	case MapEntryKind:
+		entry := e.AsMapEntry()
+		hv[id] = hv.maxHeight(entry.Key(), entry.Value())
+	case StructFieldKind:
+		hv[id] = hv[e.AsStructField().Value().ID()]
+	}
+}
+
+// maxHeight returns the largest recorded height among exprs, or zero when exprs is empty.
+func (hv heightVisitor) maxHeight(exprs ...Expr) int {
+	tallest := 0
+	for i := range exprs {
+		if h := hv[exprs[i].ID()]; h > tallest {
+			tallest = h
+		}
+	}
+	return tallest
+}
+
+// maxEntryHeight returns the largest recorded height among entries, or zero when entries is empty.
+func (hv heightVisitor) maxEntryHeight(entries ...EntryExpr) int {
+	tallest := 0
+	for i := range entries {
+		if h := hv[entries[i].ID()]; h > tallest {
+			tallest = h
+		}
+	}
+	return tallest
+}
diff --git a/common/ast/ast_test.go b/common/ast/ast_test.go
index a4a4a57..7a1c6a1 100644
--- a/common/ast/ast_test.go
+++ b/common/ast/ast_test.go
@@ -339,6 +339,31 @@
}
}
+func TestHeights(t *testing.T) {
+	// Each case pairs a type-checked expression with the expected height of its root node.
+	cases := []struct {
+		expr   string
+		height int
+	}{
+		{`'a' == 'b'`, 1},
+		{`'a'.size()`, 1},
+		{`[1, 2].size()`, 2},
+		{`size('a')`, 1},
+		{`has({'a': 1}.a)`, 2},
+		{`{'a': 1}`, 1},
+		{`{'a': 1}['a']`, 2},
+		{`[1, 2, 3].exists(i, i % 2 == 1)`, 4},
+		{`google.expr.proto3.test.TestAllTypes{}`, 1},
+		{`google.expr.proto3.test.TestAllTypes{repeated_int32: [1, 2]}`, 2},
+	}
+	for _, tc := range cases {
+		checked := mustTypeCheck(t, tc.expr)
+		rootID := checked.Expr().ID()
+		if got := ast.Heights(checked)[rootID]; got != tc.height {
+			t.Errorf("ast.Heights(%q) got max height %d, wanted %d", tc.expr, got, tc.height)
+		}
+	}
+}
+
func mockRelativeSource(t testing.TB, text string, lineOffsets []int32, baseLocation common.Location) common.Source {
t.Helper()
return &mockSource{
diff --git a/common/ast/navigable.go b/common/ast/navigable.go
index d7a90fb..13e5777 100644
--- a/common/ast/navigable.go
+++ b/common/ast/navigable.go
@@ -237,8 +237,13 @@
case StructKind:
s := expr.AsStruct()
for _, f := range s.Fields() {
- visitor.VisitEntryExpr(f)
+ if order == preOrder {
+ visitor.VisitEntryExpr(f)
+ }
visit(f.AsStructField().Value(), visitor, order, depth+1, maxDepth)
+ if order == postOrder {
+ visitor.VisitEntryExpr(f)
+ }
}
}
if order == postOrder {
diff --git a/policy/compiler.go b/policy/compiler.go
index 93505ff..bdf495a 100644
--- a/policy/compiler.go
+++ b/policy/compiler.go
@@ -198,7 +198,8 @@
if iss.Err() != nil {
return nil, iss
}
- composer := NewRuleComposer(env, p)
+ // An error cannot happen when composing without supplying options
+ composer, _ := NewRuleComposer(env)
return composer.Compose(rule)
}
diff --git a/policy/compiler_test.go b/policy/compiler_test.go
index cc3e80f..b318d2d 100644
--- a/policy/compiler_test.go
+++ b/policy/compiler_test.go
@@ -31,8 +31,55 @@
+// TestCompile parses and compiles each policy in policyTests with its environment options and no
+// compiler options, then sets up and runs it through the shared runner.
func TestCompile(t *testing.T) {
for _, tst := range policyTests {
- t.Run(tst.name, func(t *testing.T) {
- r := newRunner(t, tst.name, tst.expr, tst.parseOpts, tst.envOpts...)
+		// Copy the loop variable so the subtest closure captures this iteration's value
+		// (needed before Go 1.22 introduced per-iteration loop variables).
+ tc := tst
+ t.Run(tc.name, func(t *testing.T) {
+ r := newRunner(tc.name, tc.expr, tc.parseOpts)
+ env, ast, iss := r.compile(t, tc.envOpts, []CompilerOption{})
+ if iss.Err() != nil {
+ t.Fatalf("Compile(%s) failed: %v", r.name, iss.Err())
+ }
+ r.setup(t, env, ast)
+ r.run(t)
+ })
+ }
+}
+
+// TestRuleComposerError verifies that NewRuleComposer surfaces the error produced by an invalid
+// ExpressionUnnestHeight option.
+func TestRuleComposerError(t *testing.T) {
+	env, err := cel.NewEnv()
+	if err != nil {
+		t.Fatalf("NewEnv() failed: %v", err)
+	}
+	// ExpressionUnnestHeight rejects every non-positive height, so cover the zero boundary as well as
+	// a negative value.
+	for _, height := range []int{0, -1} {
+		_, err := NewRuleComposer(env, ExpressionUnnestHeight(height))
+		if err == nil || !strings.Contains(err.Error(), "invalid unnest") {
+			t.Errorf("NewRuleComposer(ExpressionUnnestHeight(%d)) got %v, wanted 'invalid unnest'", height, err)
+		}
+	}
+}
+
+func TestRuleComposerUnnest(t *testing.T) {
+ for _, tst := range composerUnnestTests {
+ tc := tst
+ t.Run(tc.name, func(t *testing.T) {
+ r := newRunner(tc.name, tc.expr, []ParserOption{})
+ env, rule, iss := r.compileRule(t)
+ if iss.Err() != nil {
+ t.Fatalf("CompileRule() failed: %v", iss.Err())
+ }
+ rc, err := NewRuleComposer(env, tc.composerOpts...)
+ if err != nil {
+ t.Fatalf("NewRuleComposer() failed: %v", err)
+ }
+ ast, iss := rc.Compose(rule)
+ if iss.Err() != nil {
+ t.Fatalf("Compose(rule) failed: %v", iss.Err())
+ }
+ unparsed, err := cel.AstToString(ast)
+ if err != nil {
+ t.Fatalf("cel.AstToString() failed: %v", err)
+ }
+ if normalize(unparsed) != normalize(tc.composed) {
+ t.Errorf("cel.AstToString() got %s, wanted %s", unparsed, tc.composed)
+ }
+ r.setup(t, env, ast)
r.run(t)
})
}
@@ -40,7 +87,8 @@
func TestCompileError(t *testing.T) {
for _, tst := range policyErrorTests {
- _, _, iss := compile(t, tst.name, []ParserOption{}, []cel.EnvOption{}, tst.compilerOpts)
+ policy := parsePolicy(t, tst.name, []ParserOption{})
+ _, _, iss := compile(t, tst.name, policy, []cel.EnvOption{}, tst.compilerOpts)
if iss.Err() == nil {
t.Fatalf("compile(%s) did not error, wanted %s", tst.name, tst.err)
}
@@ -98,7 +146,8 @@
wantError := `ERROR: testdata/required_labels/policy.yaml:15:8: error configuring compiler option: nested expression limit must be non-negative, non-zero value: -1
| name: "required_labels"
| .......^`
- _, _, iss := compile(t, policyName, []ParserOption{}, []cel.EnvOption{}, []CompilerOption{MaxNestedExpressions(-1)})
+ policy := parsePolicy(t, policyName, []ParserOption{})
+ _, _, iss := compile(t, policyName, policy, []cel.EnvOption{}, []CompilerOption{MaxNestedExpressions(-1)})
if iss.Err() == nil {
t.Fatalf("compile(%s) did not error, wanted %s", policyName, wantError)
}
@@ -109,55 +158,40 @@
+// BenchmarkCompile prepares each policy in policyTests (parse, compile, setup) and then calls
+// r.bench for it.
func BenchmarkCompile(b *testing.B) {
for _, tst := range policyTests {
- r := newRunner(b, tst.name, tst.expr, tst.parseOpts, tst.envOpts...)
+		r := newRunner(tst.name, tst.expr, tst.parseOpts)
+		env, ast, iss := r.compile(b, tst.envOpts, []CompilerOption{})
+		if iss.Err() != nil {
+			// Name the failing policy, matching the TestCompile failure message.
+			b.Fatalf("Compile(%s) failed: %v", r.name, iss.Err())
+		}
+		r.setup(b, env, ast)
r.bench(b)
}
}
-func newRunner(t testing.TB, name, expr string, parseOpts []ParserOption, opts ...cel.EnvOption) *runner {
- r := &runner{
+// newRunner returns a runner for the policy in testdata/<name>/ without compiling it; callers compile
+// via runner.compile or runner.compileRule and then call runner.setup.
+//
+// NOTE(review): opts is accepted but ignored; environment options are now passed to runner.compile.
+func newRunner(name, expr string, parseOpts []ParserOption, opts ...cel.EnvOption) *runner {
+ return &runner{
name: name,
- envOpts: opts,
parseOpts: parseOpts,
expr: expr}
- r.setup(t)
- return r
}
+// runner holds the state used to set up, run, and benchmark a single policy from testdata/<name>/.
type runner struct {
- name string
- envOpts []cel.EnvOption
- parseOpts []ParserOption
- compilerOpts []CompilerOption
- env *cel.Env
- expr string
- prg cel.Program
+	// name selects the policy directory under testdata/ and is expected to match the policy's name.
+ name string
+ parseOpts []ParserOption
+ env *cel.Env
+ expr string
+ prg cel.Program
}
-func mustCompileExpr(t testing.TB, env *cel.Env, expr string) *cel.Ast {
- t.Helper()
- out, iss := env.Compile(expr)
- if iss.Err() != nil {
- t.Fatalf("env.Compile(%s) failed: %v", expr, iss.Err())
- }
- return out
+// compile parses the runner's policy from testdata/<name>/policy.yaml and compiles it with the given
+// environment and compiler options.
+func (r *runner) compile(t testing.TB, envOpts []cel.EnvOption, compilerOpts []CompilerOption) (*cel.Env, *cel.Ast, *cel.Issues) {
+	// Report failures at the caller's line, like the other runner helpers.
+	t.Helper()
+	policy := parsePolicy(t, r.name, r.parseOpts)
+	return compile(t, r.name, policy, envOpts, compilerOpts)
}
-func compile(t testing.TB, name string, parseOpts []ParserOption, envOpts []cel.EnvOption, compilerOpts []CompilerOption) (*cel.Env, *cel.Ast, *cel.Issues) {
+func (r *runner) compileRule(t testing.TB) (*cel.Env, *CompiledRule, *cel.Issues) {
t.Helper()
- config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", name))
- srcFile := readPolicy(t, fmt.Sprintf("testdata/%s/policy.yaml", name))
- parser, err := NewParser(parseOpts...)
- if err != nil {
- t.Fatalf("NewParser() failed: %v", err)
- }
- policy, iss := parser.Parse(srcFile)
- if iss.Err() != nil {
- t.Fatalf("Parse() failed: %v", iss.Err())
- }
- if policy.name.Value != name {
- t.Errorf("policy name is %v, wanted %q", policy.name, name)
- }
+ config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", r.name))
+ policy := parsePolicy(t, r.name, r.parseOpts)
env, err := cel.NewCustomEnv(
cel.OptionalTypes(),
cel.EnableMacroCallTracking(),
@@ -166,26 +200,17 @@
if err != nil {
t.Fatalf("cel.NewEnv() failed: %v", err)
}
- // Configure any custom environment options.
- env, err = env.Extend(envOpts...)
- if err != nil {
- t.Fatalf("env.Extend() with env options %v, failed: %v", config, err)
- }
// Configure declarations
env, err = env.Extend(FromConfig(config))
if err != nil {
t.Fatalf("env.Extend() with config options %v, failed: %v", config, err)
}
- ast, iss := Compile(env, policy, compilerOpts...)
- return env, ast, iss
+ rule, iss := CompileRule(env, policy)
+ return env, rule, iss
}
-func (r *runner) setup(t testing.TB) {
+func (r *runner) setup(t testing.TB, env *cel.Env, ast *cel.Ast) {
t.Helper()
- env, ast, iss := compile(t, r.name, r.parseOpts, r.envOpts, r.compilerOpts)
- if iss.Err() != nil {
- t.Fatalf("Compile(%s) failed: %v", r.name, iss.Err())
- }
pExpr, err := cel.AstToString(ast)
if err != nil {
t.Fatalf("cel.AstToString() failed: %v", err)
@@ -323,6 +348,56 @@
return out
}
+// mustCompileExpr compiles expr in env and stops the test immediately if compilation reports an error.
+func mustCompileExpr(t testing.TB, env *cel.Env, expr string) *cel.Ast {
+	t.Helper()
+	compiled, issues := env.Compile(expr)
+	if err := issues.Err(); err != nil {
+		t.Fatalf("env.Compile(%s) failed: %v", expr, err)
+	}
+	return compiled
+}
+
+// parsePolicy reads testdata/<name>/policy.yaml and parses it with parseOpts, stopping the test if the
+// parser cannot be created or parsing reports an error. A policy whose declared name differs from
+// name is reported as a test error but still returned.
+func parsePolicy(t testing.TB, name string, parseOpts []ParserOption) *Policy {
+	t.Helper()
+	path := fmt.Sprintf("testdata/%s/policy.yaml", name)
+	src := readPolicy(t, path)
+	policyParser, err := NewParser(parseOpts...)
+	if err != nil {
+		t.Fatalf("NewParser() failed: %v", err)
+	}
+	parsed, iss := policyParser.Parse(src)
+	if iss.Err() != nil {
+		t.Fatalf("Parse() failed: %v", iss.Err())
+	}
+	if declared := parsed.name.Value; declared != name {
+		t.Errorf("policy name is %v, wanted %q", parsed.name, name)
+	}
+	return parsed
+}
+
+// compile builds the test environment for testdata/<name>/ and compiles policy within it.
+//
+// The base environment enables optional types, macro call tracking, extended validations, and the
+// bindings extension. It is then extended with envOpts, followed by the declarations from
+// testdata/<name>/config.yaml. Compilation issues are returned, not reported, so that error tests can
+// inspect them.
+func compile(t testing.TB, name string, policy *Policy, envOpts []cel.EnvOption, compilerOpts []CompilerOption) (*cel.Env, *cel.Ast, *cel.Issues) {
+	t.Helper()
+	config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", name))
+	env, err := cel.NewCustomEnv(
+		cel.OptionalTypes(),
+		cel.EnableMacroCallTracking(),
+		cel.ExtendedValidations(),
+		ext.Bindings())
+	if err != nil {
+		t.Fatalf("cel.NewEnv() failed: %v", err)
+	}
+	// Configure any custom environment options.
+	env, err = env.Extend(envOpts...)
+	if err != nil {
+		// Report the options that failed, not the policy config.
+		t.Fatalf("env.Extend() with env options %v, failed: %v", envOpts, err)
+	}
+	// Configure declarations
+	env, err = env.Extend(FromConfig(config))
+	if err != nil {
+		t.Fatalf("env.Extend() with config options %v, failed: %v", config, err)
+	}
+	ast, iss := Compile(env, policy, compilerOpts...)
+	return env, ast, iss
+}
+
func normalize(s string) string {
return strings.ReplaceAll(
strings.ReplaceAll(
diff --git a/policy/composer.go b/policy/composer.go
index be326de..0b9be2a 100644
--- a/policy/composer.go
+++ b/policy/composer.go
@@ -15,7 +15,9 @@
package policy
import (
+ "cmp"
"fmt"
+ "slices"
"strings"
"github.com/google/cel-go/cel"
@@ -24,25 +26,58 @@
"github.com/google/cel-go/common/types"
)
+// ComposerOption configures a RuleComposer. It returns the (possibly modified) composer, or an error
+// when the requested configuration is invalid.
+type ComposerOption func(*RuleComposer) (*RuleComposer, error)
+
+// ExpressionUnnestHeight sets the expression height at or above which nested call expressions are
+// split out into local index variables within the cel.@block declaration.
+//
+// The height must be positive; otherwise applying the option returns an error.
+func ExpressionUnnestHeight(height int) ComposerOption {
+	return func(c *RuleComposer) (*RuleComposer, error) {
+		if height > 0 {
+			c.exprUnnestHeight = height
+			return c, nil
+		}
+		return nil, fmt.Errorf("invalid unnest height: value must be positive: %d", height)
+	}
+}
+
// NewRuleComposer creates a rule composer which stitches together rules within a policy into
// a single CEL expression.
+//
+// Options are applied in order. If an option returns an error, construction stops and that error is
+// returned with a nil composer. Without options, the unnest height defaults to 25.
-func NewRuleComposer(env *cel.Env, p *Policy) *RuleComposer {
- return &RuleComposer{
+func NewRuleComposer(env *cel.Env, opts ...ComposerOption) (*RuleComposer, error) {
+ composer := &RuleComposer{
env: env,
- p: p,
+		// Default unnest height; override it with ExpressionUnnestHeight.
+ exprUnnestHeight: 25,
}
+ var err error
+ for _, opt := range opts {
+ composer, err = opt(composer)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return composer, nil
}
// RuleComposer optimizes a set of expressions into a single expression.
type RuleComposer struct {
+	// env compiles the placeholder root and runs the composing optimizer.
env *cel.Env
- p *Policy
+
+	// exprUnnestHeight is the minimum height at which a nested call expression is split into an
+	// index variable within the cel.@block declaration of the composed rule. It is set with
+	// ExpressionUnnestHeight and defaults to 25 in NewRuleComposer.
+ exprUnnestHeight int
}
// Compose stitches together a set of expressions within a CompiledRule into a single CEL ast.
+//
+// Call expressions at or above the composer's unnest height (see ExpressionUnnestHeight) may be
+// hoisted into cel.@block index variables as part of composition.
func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) {
+	// The compiled `true` is only a placeholder root that the optimizer replaces entirely with the
+	// composed rule. NOTE(review): the compile error is discarded on the assumption that compiling a
+	// literal cannot fail; confirm.
ruleRoot, _ := c.env.Compile("true")
- opt := cel.NewStaticOptimizer(&ruleComposerImpl{rule: r, varIndices: []varIndex{}})
+ opt := cel.NewStaticOptimizer(
+ &ruleComposerImpl{
+ rule: r,
+ varIndices: []varIndex{},
+ exprUnnestHeight: c.exprUnnestHeight,
+ })
return opt.Optimize(c.env, ruleRoot)
}
@@ -51,7 +86,7 @@
indexVar string
localVar string
expr ast.Expr
- cv *CompiledVariable
+ celType *types.Type
}
type ruleComposerImpl struct {
@@ -59,7 +94,7 @@
nextVarIndex int
varIndices []varIndex
- maxNestedExpressionLimit int
+ exprUnnestHeight int
}
// Optimize implements an AST optimizer for CEL which composes an expression graph into a single
@@ -68,17 +103,23 @@
// The input to optimize is a dummy expression which is completely replaced according
// to the configuration of the rule composition graph.
ruleExpr := opt.optimizeRule(ctx, opt.rule)
+ // If the rule is deeply nested, it may need to be unnested. This process may generate
+ // additional variables that are included in the `sortedVariables` list.
+ ruleExpr = opt.maybeUnnestRule(ctx, ruleExpr)
+
+ // Collect all variables associated with the rule expression.
allVars := opt.sortedVariables()
// If there were no variables, return the expression.
if len(allVars) == 0 {
return ctx.NewAST(ruleExpr)
}
- // Otherwise populate the block.
+ // Otherwise populate the cel.@block with the variable declarations and wrap the expression
+ // in the block.
varExprs := make([]ast.Expr, len(allVars))
for i, vi := range allVars {
varExprs[i] = vi.expr
- err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.cv.Declaration().Type()))
+ err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.celType))
if err != nil {
ctx.ReportErrorAtID(ruleExpr.ID(), "%s", err.Error())
}
@@ -156,6 +197,57 @@
})
}
+// maybeUnnestRule hoists deeply nested call expressions out of the composed rule expression and into
+// cel.@block index variables.
+//
+// A candidate is a call expression whose height is at least exprUnnestHeight. The following are never
+// hoisted:
+//   - candidates inside a comprehension's accumulator initializer, loop condition, loop step, or
+//     result (the iteration range stays eligible);
+//   - the root of the rule expression, since replacing it would leave a block whose result is a single
+//     index variable.
+//
+// Candidates are processed from shortest to tallest. After each one is hoisted, the recorded heights
+// of its ancestors are reduced (see reduceHeight), so taller candidates are re-checked against the
+// partially unnested tree. The rule expression is rewritten in place and returned.
+func (opt *ruleComposerImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr ast.Expr) ast.Expr {
+	ruleAST := ctx.NewAST(ruleExpr)
+	ruleNav := ast.NavigateAST(ruleAST)
+	heights := ast.Heights(ruleAST)
+
+	// unnestMap tracks which candidates are still eligible; unnestExprs keeps them in visit order.
+	unnestMap := map[int64]bool{}
+	unnestExprs := []ast.NavigableExpr{}
+	// ast.MatchDescendants visits from leaf to root, so when a comprehension is reached every
+	// candidate inside its body has already been recorded and can be discarded.
+	ast.MatchDescendants(ruleNav, func(e ast.NavigableExpr) bool {
+		if e.Kind() == ast.ComprehensionKind {
+			// Only the map entry is cleared; the stale list entry is skipped below.
+			removeIneligibleSubExprs(e, unnestMap)
+			return false
+		}
+		if e.Kind() != ast.CallKind || heights[e.ID()] < opt.exprUnnestHeight {
+			return false
+		}
+		unnestMap[e.ID()] = true
+		unnestExprs = append(unnestExprs, e)
+		return true
+	})
+
+	// Shortest candidates first; the stable sort keeps visit order among equal heights.
+	slices.SortStableFunc(unnestExprs, func(a, b ast.NavigableExpr) int {
+		return cmp.Compare(heights[a.ID()], heights[b.ID()])
+	})
+
+	// Skip the root by identity rather than by position: the tallest candidate is only the root when
+	// the root is itself a qualifying call.
+	rootID := ruleExpr.ID()
+	for _, e := range unnestExprs {
+		if e.ID() == rootID || !unnestMap[e.ID()] {
+			continue
+		}
+		// Hoisting shorter candidates may already have brought this one below the threshold.
+		if heights[e.ID()] < opt.exprUnnestHeight {
+			continue
+		}
+		reduceHeight(heights, e, opt.exprUnnestHeight)
+		opt.registerBranchVariable(ctx, e)
+	}
+	return ruleExpr
+}
+
// registerVariable creates an entry for a variable name within the cel.@block used to enumerate
// variables within composed policy expression.
func (opt *ruleComposerImpl) registerVariable(ctx *cel.OptimizerContext, v *CompiledVariable) {
@@ -168,7 +260,23 @@
indexVar: indexVar,
localVar: varName,
expr: varExpr,
- cv: v}
+ celType: v.Declaration().Type()}
+ opt.varIndices = append(opt.varIndices, vi)
+ opt.nextVarIndex++
+}
+
+// registerBranchVariable hoists varExpr into a new cel.@block index variable and replaces varExpr, in
+// place, with an identifier that references that variable.
+//
+// It is used to unnest any deeply nested call expression, not only logical branches and operators.
+// The variable is named @index<N> for the next free index and is declared with the type reported by
+// varExpr.Type().
+func (opt *ruleComposerImpl) registerBranchVariable(ctx *cel.OptimizerContext, varExpr ast.NavigableExpr) {
+ indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex)
+	// Copy the subtree first: UpdateExpr below replaces varExpr with the index identifier.
+ varExprCopy := ctx.CopyASTAndMetadata(ctx.NewAST(varExpr))
+ vi := varIndex{
+ index: opt.nextVarIndex,
+ indexVar: indexVar,
+ expr: varExprCopy,
+ celType: varExpr.Type(),
+ }
+ ctx.UpdateExpr(varExpr, ctx.NewIdent(vi.indexVar))
opt.varIndices = append(opt.varIndices, vi)
opt.nextVarIndex++
}
@@ -270,6 +378,7 @@
)
}
// The `step` is pruned away by a unconditional non-optional step `s`.
+ // Likely a candidate for dead-code warnings.
return s
}
return newNonOptionalCompositionStep(ctx,
@@ -362,3 +471,42 @@
e.AsCall().FunctionName() == "optional.none" &&
len(e.AsCall().Args()) == 0
}
+
+// removeIneligibleSubExprs drops from unnestMap every candidate that lies within the accumulator
+// initializer, loop condition, loop step, or result of the comprehension e.
+func removeIneligibleSubExprs(e ast.NavigableExpr, unnestMap map[int64]bool) {
+	ineligible := comprehensionSubExprIDs(e)
+	for i := range ineligible {
+		// delete is a no-op for ids that were never recorded as candidates.
+		delete(unnestMap, ineligible[i])
+	}
+}
+
+// comprehensionSubExprIDs returns the ids of every expression in the comprehension e's accumulator
+// initializer, loop condition, loop step, and result subtrees. It is similar to e.Children(), but
+// omits the iteration range subtree.
+//
+// The type assertions panic if the comprehension's parts are not navigable expressions.
+func comprehensionSubExprIDs(e ast.NavigableExpr) []int64 {
+	comp := e.AsComprehension()
+	accuInit := comp.AccuInit().(ast.NavigableExpr)
+	loopCond := comp.LoopCondition().(ast.NavigableExpr)
+	loopStep := comp.LoopStep().(ast.NavigableExpr)
+	result := comp.Result().(ast.NavigableExpr)
+	return enumerateExprIDs(accuInit, loopCond, loopStep, result)
+}
+
+// enumerateExprIDs returns the ids of exprs and of all their descendants in pre-order: each
+// expression's id precedes the ids of its children.
+func enumerateExprIDs(exprs ...ast.NavigableExpr) []int64 {
+	return appendExprIDs(make([]int64, 0, len(exprs)), exprs...)
+}
+
+// appendExprIDs appends the pre-order ids of exprs and their descendants to ids and returns the
+// extended slice. Threading a single slice through the recursion avoids re-copying the ids collected
+// at every level.
+func appendExprIDs(ids []int64, exprs ...ast.NavigableExpr) []int64 {
+	for _, e := range exprs {
+		ids = append(ids, e.ID())
+		ids = appendExprIDs(ids, e.Children()...)
+	}
+	return ids
+}
+
+// reduceHeight lowers the recorded height of e, and then of each ancestor of e, by amount. The walk
+// up the parent chain stops at the first node whose recorded height is already below amount; that
+// node is left unchanged.
+//
+// NOTE(review): this is a heuristic. A node's true height is one more than the tallest of its
+// children, so subtracting the same amount from every ancestor ignores sibling subtrees that remain
+// tall. Also, e is lowered by amount rather than to zero, even though maybeUnnestRule then replaces
+// it with an identifier. Changing this would likely change the composed expressions expected by the
+// unnest tests.
+func reduceHeight(heights map[int64]int, e ast.NavigableExpr, amount int) {
+ height := heights[e.ID()]
+ if height < amount {
+ return
+ }
+ heights[e.ID()] = height - amount
+ if parent, hasParent := e.Parent(); hasParent {
+ reduceHeight(heights, parent, amount)
+ }
+}
diff --git a/policy/helper_test.go b/policy/helper_test.go
index 3934301..fbe3183 100644
--- a/policy/helper_test.go
+++ b/policy/helper_test.go
@@ -35,7 +35,6 @@
envOpts []cel.EnvOption
parseOpts []ParserOption
expr string
- expr2 string
}{
{
name: "k8s",
@@ -114,6 +113,19 @@
.or((x > 1) ? optional.of(false) : optional.none()))`,
},
{
+ name: "unnest",
+ expr: `
+ cel.@block([values.filter(x, x > 2)],
+ ((@index0.size() == 0) ? false : @index0.all(x, x % 2 == 0))
+ ? optional.of("some divisible by 2")
+ : (values.map(x, x * 3).exists(x, x % 4 == 0)
+ ? optional.of("at least one divisible by 4")
+ : (values.map(x, x * x * x).exists(x, x % 6 == 0)
+ ? optional.of("at least one power of 6")
+ : optional.none())))
+ `,
+ },
+ {
name: "context_pb",
expr: `
(single_int32 > google.expr.proto3.test.TestAllTypes{single_int64: 10}.single_int64)
@@ -145,7 +157,7 @@
cel.@block([
spec.labels,
@index0.filter(l, !(l in resource.labels)),
- resource.labels.transformList(l, value, l in @index0 && value != @index0[l], l)],
+ resource.labels.transformList(l, value, l in @index0 && value != @index0[l], l)],
(@index1.size() > 0)
? optional.of("missing one or more required labels: %s".format([@index1]))
: ((@index2.size() > 0)
@@ -199,6 +211,140 @@
},
}
+ composerUnnestTests = []struct {
+ name string
+ expr string
+ composed string
+ composerOpts []ComposerOption
+ }{
+ {
+ name: "unnest",
+ composerOpts: []ComposerOption{ExpressionUnnestHeight(2)},
+ composed: `
+ cel.@block([
+ values.filter(x, x > 2),
+ @index0.size() == 0,
+ @index1 ? false : @index0.all(x, x % 2 == 0),
+ values.map(x, x * x * x).exists(x, x % 6 == 0)
+ ? optional.of("at least one power of 6")
+ : optional.none(),
+ values.map(x, x * 3).exists(x, x % 4 == 0)
+ ? optional.of("at least one divisible by 4")
+ : @index3],
+ @index2 ? optional.of("some divisible by 2") : @index4)
+ `,
+ },
+ {
+ name: "required_labels",
+ composerOpts: []ComposerOption{ExpressionUnnestHeight(2)},
+ composed: `
+ cel.@block([
+ spec.labels,
+ @index0.filter(l, !(l in resource.labels)),
+ resource.labels.transformList(l, value, l in @index0 && value != @index0[l], l),
+ @index1.size() > 0,
+ "missing one or more required labels: %s".format([@index1]),
+ @index2.size() > 0,
+ "invalid values provided on one or more labels: %s".format([@index2])],
+ @index3 ? optional.of(@index4) : (@index5 ? optional.of(@index6) : optional.none()))
+ `,
+ },
+ {
+ name: "required_labels",
+ composerOpts: []ComposerOption{ExpressionUnnestHeight(4)},
+ composed: `
+ cel.@block([
+ spec.labels,
+ @index0.filter(l, !(l in resource.labels)),
+ resource.labels.transformList(l, value, l in @index0 && value != @index0[l], l),
+ (@index2.size() > 0)
+ ? optional.of("invalid values provided on one or more labels: %s".format([@index2]))
+ : optional.none()
+ ],
+ (@index1.size() > 0)
+ ? optional.of("missing one or more required labels: %s".format([@index1]))
+ : @index3)`,
+ },
+ {
+ name: "nested_rule2",
+ composerOpts: []ComposerOption{ExpressionUnnestHeight(4)},
+ composed: `
+ cel.@block([
+ ["us", "uk", "es"],
+ {"us": false, "ru": false, "ir": false},
+ resource.origin in @index1 && !(resource.origin in @index0),
+ !(resource.origin in @index0) ? {"banned": "unconfigured_region"} : {}],
+ resource.?user.orValue("").startsWith("bad")
+ ? (@index2 ? {"banned": "restricted_region"} : {"banned": "bad_actor"})
+ : @index3)`,
+ },
+ {
+ name: "nested_rule2",
+ composerOpts: []ComposerOption{ExpressionUnnestHeight(5)},
+ composed: `
+ cel.@block([
+ ["us", "uk", "es"],
+ {"us": false, "ru": false, "ir": false},
+ (resource.origin in @index1 && !(resource.origin in @index0))
+ ? {"banned": "restricted_region"}
+ : {"banned": "bad_actor"}],
+ resource.?user.orValue("").startsWith("bad")
+ ? @index2
+ : (!(resource.origin in @index0)
+ ? {"banned": "unconfigured_region"}
+ : {}))`,
+ },
+ {
+ name: "limits",
+ composerOpts: []ComposerOption{ExpressionUnnestHeight(3)},
+ composed: `
+ cel.@block([
+ "hello",
+ "goodbye",
+ "me",
+ "%s, %s",
+ @index3.format([@index1, @index2]),
+ (now.getHours() < 24) ? optional.of(@index4 + "!!!") : optional.none(),
+ optional.of(@index3.format([@index0, @index2]))],
+ (now.getHours() >= 20)
+ ? ((now.getHours() < 21) ? optional.of(@index4 + "!") :
+ ((now.getHours() < 22) ? optional.of(@index4 + "!!") : @index5))
+ : @index6)`,
+ },
+ {
+ name: "limits",
+ composerOpts: []ComposerOption{ExpressionUnnestHeight(4)},
+ composed: `
+ cel.@block([
+ "hello",
+ "goodbye",
+ "me",
+ "%s, %s",
+ @index3.format([@index1, @index2]),
+ (now.getHours() < 22) ? optional.of(@index4 + "!!") :
+ ((now.getHours() < 24) ? optional.of(@index4 + "!!!") : optional.none())],
+ (now.getHours() >= 20)
+ ? ((now.getHours() < 21) ? optional.of(@index4 + "!") : @index5)
+ : optional.of(@index3.format([@index0, @index2])))
+ `,
+ },
+ {
+ name: "limits",
+ composerOpts: []ComposerOption{ExpressionUnnestHeight(5)},
+ composed: `
+ cel.@block([
+ "hello",
+ "goodbye",
+ "me",
+ "%s, %s",
+ @index3.format([@index1, @index2]),
+ (now.getHours() < 21) ? optional.of(@index4 + "!") :
+ ((now.getHours() < 22) ? optional.of(@index4 + "!!") :
+ ((now.getHours() < 24) ? optional.of(@index4 + "!!!") : optional.none()))],
+ (now.getHours() >= 20) ? @index5 : optional.of(@index3.format([@index0, @index2])))`,
+ },
+ }
+
policyErrorTests = []struct {
name string
err string
diff --git a/policy/testdata/unnest/config.yaml b/policy/testdata/unnest/config.yaml
new file mode 100644
index 0000000..1891ed6
--- /dev/null
+++ b/policy/testdata/unnest/config.yaml
@@ -0,0 +1,20 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "unnest"
+variables:
+ - name: values
+ type_name: list
+ params:
+ - type_name: int
diff --git a/policy/testdata/unnest/policy.yaml b/policy/testdata/unnest/policy.yaml
new file mode 100644
index 0000000..af63683
--- /dev/null
+++ b/policy/testdata/unnest/policy.yaml
@@ -0,0 +1,32 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: unnest
+rule:
+ variables:
+ - name: even_greater
+ expression: >
+ values.filter(x, x > 2)
+ match:
+ - condition: >
+ variables.even_greater.size() == 0 ? false :
+ variables.even_greater.all(x, x % 2 == 0)
+ output: >
+ "some divisible by 2"
+ - condition: "values.map(x, x * 3).exists(x, x % 4 == 0)"
+ output: >
+ "at least one divisible by 4"
+ - condition: "values.map(x, x * x * x).exists(x, x % 6 == 0)"
+ output: >
+ "at least one power of 6"
diff --git a/policy/testdata/unnest/tests.yaml b/policy/testdata/unnest/tests.yaml
new file mode 100644
index 0000000..9bed7b3
--- /dev/null
+++ b/policy/testdata/unnest/tests.yaml
@@ -0,0 +1,50 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+description: "Unnest tests unnesting of comprehension sequences"
+section:
+ - name: "divisible by 2"
+ tests:
+ - name: "true"
+ input:
+ values:
+ expr: "[4, 6]"
+ output: >
+ "some divisible by 2"
+ - name: "false"
+ input:
+ values:
+ expr: "[1, 3, 5]"
+ output: "optional.none()"
+ - name: "empty-set"
+ input:
+ values:
+ expr: "[1, 2]"
+ output: "optional.none()"
+ - name: "divisible by 4"
+ tests:
+ - name: "true"
+ input:
+ values:
+ expr: "[4, 7]"
+ output: >
+ "at least one divisible by 4"
+ - name: "power of 6"
+ tests:
+ - name: "true"
+ input:
+ values:
+ expr: "[6, 7]"
+ output: >
+ "at least one power of 6"