diff --git a/Include/internal/pycore_optimizer.h b/Include/internal/pycore_optimizer.h index 0592221f15226e..13c91200b0a67d 100644 --- a/Include/internal/pycore_optimizer.h +++ b/Include/internal/pycore_optimizer.h @@ -205,6 +205,8 @@ extern JitOptRef _Py_uop_sym_new_truthiness(JitOptContext *ctx, JitOptRef value, extern bool _Py_uop_sym_is_compact_int(JitOptRef sym); extern JitOptRef _Py_uop_sym_new_compact_int(JitOptContext *ctx); extern void _Py_uop_sym_set_compact_int(JitOptContext *ctx, JitOptRef sym); +extern JitOptRef _Py_uop_sym_new_predicate(JitOptContext *ctx, JitOptRef lhs_ref, JitOptRef rhs_ref, JitOptPredicateKind kind, bool invert); +extern void _Py_uop_sym_apply_predicate_narrowing(JitOptContext *ctx, JitOptRef sym, bool branch_is_true); extern void _Py_uop_abstractcontext_init(JitOptContext *ctx); extern void _Py_uop_abstractcontext_fini(JitOptContext *ctx); diff --git a/Include/internal/pycore_optimizer_types.h b/Include/internal/pycore_optimizer_types.h index 6501ce869c1425..c2ecb8b8a66164 100644 --- a/Include/internal/pycore_optimizer_types.h +++ b/Include/internal/pycore_optimizer_types.h @@ -40,6 +40,7 @@ typedef enum _JitSymType { JIT_SYM_TUPLE_TAG = 8, JIT_SYM_TRUTHINESS_TAG = 9, JIT_SYM_COMPACT_INT = 10, + JIT_SYM_PREDICATE_TAG = 11, } JitSymType; typedef struct _jit_opt_known_class { @@ -72,6 +73,18 @@ typedef struct { uint16_t value; } JitOptTruthiness; +typedef enum { + JIT_PRED_IS, +} JitOptPredicateKind; + +typedef struct { + uint8_t tag; + uint8_t kind; + bool invert; + uint16_t lhs; + uint16_t rhs; +} JitOptPredicate; + typedef struct { uint8_t tag; } JitOptCompactInt; @@ -84,6 +97,7 @@ typedef union _jit_opt_symbol { JitOptTuple tuple; JitOptTruthiness truthiness; JitOptCompactInt compact; + JitOptPredicate predicate; } JitOptSymbol; // This mimics the _PyStackRef API diff --git a/Lib/test/test_capi/test_opt.py b/Lib/test/test_capi/test_opt.py index 79c7f530b8ae89..2f63661cd3f9cc 100644 --- a/Lib/test/test_capi/test_opt.py +++ b/Lib/test/test_capi/test_opt.py @@ -3534,6 +3534,46 @@ def test_is_none(n): self.assertIn("_POP_TOP_NOP", uops) self.assertLessEqual(count_ops(ex, "_POP_TOP"), 2) + def test_is_true_narrows_to_constant(self): + def f(n): + def return_true(): + return True + + hits = 0 + v = return_true() + for i in range(n): + if v is True: + hits += v + 1 + return hits + + res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD) + self.assertEqual(res, TIER2_THRESHOLD * 2) + self.assertIsNotNone(ex) + uops = get_opnames(ex) + + # v + 1 should be constant folded + self.assertNotIn("_BINARY_OP", uops) + + def test_is_false_narrows_to_constant(self): + def f(n): + def return_false(): + return False + + hits = 0 + v = return_false() + for i in range(n): + if v is False: + hits += v + 1 + return hits + + res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD) + self.assertEqual(res, TIER2_THRESHOLD) + self.assertIsNotNone(ex) + uops = get_opnames(ex) + + # v + 1 should be constant folded + self.assertNotIn("_BINARY_OP", uops) + def test_for_iter_gen_frame(self): def f(n): for i in range(n): diff --git a/Python/optimizer_analysis.c b/Python/optimizer_analysis.c index d635ebabf9007a..17b3b7973a35da 100644 --- a/Python/optimizer_analysis.c +++ b/Python/optimizer_analysis.c @@ -247,6 +247,8 @@ add_op(JitOptContext *ctx, _PyUOpInstruction *this_instr, #define sym_is_compact_int _Py_uop_sym_is_compact_int #define sym_new_compact_int _Py_uop_sym_new_compact_int #define sym_new_truthiness _Py_uop_sym_new_truthiness +#define sym_new_predicate _Py_uop_sym_new_predicate +#define sym_apply_predicate_narrowing _Py_uop_sym_apply_predicate_narrowing #define JUMP_TO_LABEL(label) goto label; diff --git a/Python/optimizer_bytecodes.c b/Python/optimizer_bytecodes.c index 876ba7c6de7482..f7ed6292c9114e 100644 --- a/Python/optimizer_bytecodes.c +++ b/Python/optimizer_bytecodes.c @@ -38,6 +38,8 @@ typedef struct _Py_UOpsAbstractFrame _Py_UOpsAbstractFrame; #define sym_new_compact_int _Py_uop_sym_new_compact_int #define sym_is_compact_int _Py_uop_sym_is_compact_int #define sym_new_truthiness _Py_uop_sym_new_truthiness +#define sym_new_predicate _Py_uop_sym_new_predicate +#define sym_apply_predicate_narrowing _Py_uop_sym_apply_predicate_narrowing extern int optimize_to_bool( @@ -533,7 +535,7 @@ dummy_func(void) { } op(_IS_OP, (left, right -- b, l, r)) { - b = sym_new_type(ctx, &PyBool_Type); + b = sym_new_predicate(ctx, left, right, JIT_PRED_IS, oparg != 0); l = left; r = right; } @@ -1142,6 +1144,7 @@ dummy_func(void) { assert(value != NULL); eliminate_pop_guard(this_instr, ctx, value != Py_True); } + sym_apply_predicate_narrowing(ctx, flag, true); sym_set_const(flag, Py_True); } @@ -1187,6 +1190,7 @@ dummy_func(void) { assert(value != NULL); eliminate_pop_guard(this_instr, ctx, value != Py_False); } + sym_apply_predicate_narrowing(ctx, flag, false); sym_set_const(flag, Py_False); } diff --git a/Python/optimizer_cases.c.h b/Python/optimizer_cases.c.h index 012fe16bfd9096..1eedcc31fdcd4f 100644 --- a/Python/optimizer_cases.c.h +++ b/Python/optimizer_cases.c.h @@ -2293,7 +2293,7 @@ JitOptRef r; right = stack_pointer[-1]; left = stack_pointer[-2]; - b = sym_new_type(ctx, &PyBool_Type); + b = sym_new_predicate(ctx, left, right, JIT_PRED_IS, oparg != 0); l = left; r = right; CHECK_STACK_BOUNDS(1); @@ -3720,6 +3720,7 @@ assert(value != NULL); eliminate_pop_guard(this_instr, ctx, value != Py_True); } + sym_apply_predicate_narrowing(ctx, flag, true); sym_set_const(flag, Py_True); CHECK_STACK_BOUNDS(-1); stack_pointer += -1; @@ -3735,6 +3736,7 @@ assert(value != NULL); eliminate_pop_guard(this_instr, ctx, value != Py_False); } + sym_apply_predicate_narrowing(ctx, flag, false); sym_set_const(flag, Py_False); CHECK_STACK_BOUNDS(-1); stack_pointer += -1; diff --git a/Python/optimizer_symbols.c b/Python/optimizer_symbols.c index 5f5086d33b5c4c..361c70d1214363 100644 --- a/Python/optimizer_symbols.c +++ b/Python/optimizer_symbols.c @@ -309,6 +309,7 @@ _Py_uop_sym_set_type(JitOptContext *ctx, JitOptRef ref, PyTypeObject *typ) sym->cls.version = 0; sym->cls.type = typ; return; + case JIT_SYM_PREDICATE_TAG: case JIT_SYM_TRUTHINESS_TAG: if (typ != &PyBool_Type) { sym_set_bottom(ctx, sym); @@ -370,6 +371,7 @@ _Py_uop_sym_set_type_version(JitOptContext *ctx, JitOptRef ref, unsigned int ver sym->tag = JIT_SYM_TYPE_VERSION_TAG; sym->version.version = version; return true; + case JIT_SYM_PREDICATE_TAG: case JIT_SYM_TRUTHINESS_TAG: if (version != PyBool_Type.tp_version_tag) { sym_set_bottom(ctx, sym); @@ -436,6 +438,13 @@ _Py_uop_sym_set_const(JitOptContext *ctx, JitOptRef ref, PyObject *const_val) case JIT_SYM_UNKNOWN_TAG: make_const(sym, const_val); return; + case JIT_SYM_PREDICATE_TAG: + if (!PyBool_Check(const_val)) { + sym_set_bottom(ctx, sym); + return; + } + make_const(sym, const_val); + return; case JIT_SYM_TRUTHINESS_TAG: if (!PyBool_Check(const_val) || (_Py_uop_sym_is_const(ctx, ref) && @@ -589,6 +598,7 @@ _Py_uop_sym_get_type(JitOptRef ref) return _PyType_LookupByVersion(sym->version.version); case JIT_SYM_TUPLE_TAG: return &PyTuple_Type; + case JIT_SYM_PREDICATE_TAG: case JIT_SYM_TRUTHINESS_TAG: return &PyBool_Type; case JIT_SYM_COMPACT_INT: @@ -617,6 +627,7 @@ _Py_uop_sym_get_type_version(JitOptRef ref) return Py_TYPE(sym->value.value)->tp_version_tag; case JIT_SYM_TUPLE_TAG: return PyTuple_Type.tp_version_tag; + case JIT_SYM_PREDICATE_TAG: case JIT_SYM_TRUTHINESS_TAG: return PyBool_Type.tp_version_tag; case JIT_SYM_COMPACT_INT: @@ -810,6 +821,7 @@ _Py_uop_sym_set_compact_int(JitOptContext *ctx, JitOptRef ref) } return; case JIT_SYM_TUPLE_TAG: + case JIT_SYM_PREDICATE_TAG: case JIT_SYM_TRUTHINESS_TAG: sym_set_bottom(ctx, sym); return; @@ -823,6 +835,59 @@ _Py_uop_sym_set_compact_int(JitOptContext *ctx, JitOptRef ref) } } +JitOptRef +_Py_uop_sym_new_predicate(JitOptContext *ctx, JitOptRef lhs_ref, JitOptRef rhs_ref, JitOptPredicateKind kind, bool invert) +{ + JitOptSymbol *lhs = PyJitRef_Unwrap(lhs_ref); + JitOptSymbol *rhs = PyJitRef_Unwrap(rhs_ref); + + JitOptSymbol *res = sym_new(ctx); + if (res == NULL) { + return out_of_space_ref(ctx); + } + + res->tag = JIT_SYM_PREDICATE_TAG; + res->predicate.invert = invert; + res->predicate.kind = kind; + res->predicate.lhs = (uint16_t)(lhs - allocation_base(ctx)); + res->predicate.rhs = (uint16_t)(rhs - allocation_base(ctx)); + + return PyJitRef_Wrap(res); +} + +void +_Py_uop_sym_apply_predicate_narrowing(JitOptContext *ctx, JitOptRef ref, bool branch_is_true) +{ + JitOptSymbol *sym = PyJitRef_Unwrap(ref); + if (sym->tag != JIT_SYM_PREDICATE_TAG) { + return; + } + + JitOptPredicate pred = sym->predicate; + bool narrow = (branch_is_true && !pred.invert) || (!branch_is_true && pred.invert); + if (!narrow) { + return; + } + + JitOptRef lhs_ref = PyJitRef_Wrap(allocation_base(ctx) + pred.lhs); + JitOptRef rhs_ref = PyJitRef_Wrap(allocation_base(ctx) + pred.rhs); + + bool lhs_is_const = _Py_uop_sym_is_safe_const(ctx, lhs_ref); + bool rhs_is_const = _Py_uop_sym_is_safe_const(ctx, rhs_ref); + + if (pred.kind == JIT_PRED_IS && (lhs_is_const || rhs_is_const)) { + JitOptRef subject_ref = lhs_is_const ? rhs_ref : lhs_ref; + JitOptRef const_ref = lhs_is_const ? lhs_ref : rhs_ref; + + PyObject *const_val = _Py_uop_sym_get_const(ctx, const_ref); + if (const_val == NULL) { + return; + } + _Py_uop_sym_set_const(ctx, subject_ref, const_val); + assert(_Py_uop_sym_is_safe_const(ctx, subject_ref)); + } +} + JitOptRef _Py_uop_sym_new_truthiness(JitOptContext *ctx, JitOptRef ref, bool truthy) { @@ -1159,6 +1224,70 @@ _Py_uop_symbols_test(PyObject *Py_UNUSED(self), PyObject *Py_UNUSED(ignored)) TEST_PREDICATE(_Py_uop_sym_is_const(ctx, value) == true, "value is not constant"); TEST_PREDICATE(_Py_uop_sym_get_const(ctx, value) == Py_True, "value is not True"); + // Resolving predicate result to True should narrow subject to True + JitOptRef subject = _Py_uop_sym_new_unknown(ctx); + JitOptRef const_true = _Py_uop_sym_new_const(ctx, Py_True); + if (PyJitRef_IsNull(subject) || PyJitRef_IsNull(const_true)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_true, JIT_PRED_IS, false); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true); + TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject"); + TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == Py_True, "predicate narrowing did not narrow subject to True"); + + // Resolving predicate result to False should not narrow subject + subject = _Py_uop_sym_new_unknown(ctx); + if (PyJitRef_IsNull(subject)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_true, JIT_PRED_IS, false); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false); + TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject"); + + // Resolving inverted predicate to False should narrow subject to True + subject = _Py_uop_sym_new_unknown(ctx); + if (PyJitRef_IsNull(subject)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_true, JIT_PRED_IS, true); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false); + TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing (inverted) did not const-narrow subject"); + TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == Py_True, "predicate narrowing (inverted) did not narrow subject to True"); + + // Resolving inverted predicate to True should not narrow subject + subject = _Py_uop_sym_new_unknown(ctx); + if (PyJitRef_IsNull(subject)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_true, JIT_PRED_IS, true); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true); + TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject (inverted/true)"); + + // Test narrowing subject to None + subject = _Py_uop_sym_new_unknown(ctx); + JitOptRef const_none = _Py_uop_sym_new_const(ctx, Py_None); + if (PyJitRef_IsNull(subject) || PyJitRef_IsNull(const_none)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_none, JIT_PRED_IS, false); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true); + TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (None)"); + TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == Py_None, "predicate narrowing did not narrow subject to None"); val_big = PyNumber_Lshift(_PyLong_GetOne(), PyLong_FromLong(66)); if (val_big == NULL) {