/// @title CallFrameContextAssertion
/// @notice For the `Protocol.transfer` call that triggered this assertion, checks that the caller's
///         balance slot decreased by `amount` and the recipient's balance slot increased by `amount`,
///         comparing storage immediately before and immediately after that call frame.
contract CallFrameContextAssertion is Assertion {
    /// @notice Registers the assertion spec this contract is written against.
    constructor() {
        registerAssertionSpec(AssertionSpec.Reshiram);
    }

    /// @notice Registers `assertCallFrameContext` to run on calls to `Protocol.transfer`.
    function triggers() public view override {
        registerFnCallTrigger(this.assertCallFrameContext.selector, Protocol.transfer.selector);
    }

    /// @notice Verifies the balance movement of the triggering `transfer` call frame.
    /// @dev Reverts (assertion failure) if:
    ///      - no recorded `transfer` call has `id == ctx.callStart`;
    ///      - for a self-transfer (`from == to`), the shared balance slot changed;
    ///      - otherwise, the sender balance did not drop by exactly `amount`, or the recipient
    ///        balance did not rise by exactly `amount`.
    function assertCallFrameContext() public view {
        Protocol protocol = Protocol(ph.getAssertionAdopter());
        PhEvm.TriggerContext memory ctx = ph.context();

        // Locate the specific transfer call that triggered this run among all transfer calls.
        PhEvm.CallInputs[] memory transferCalls = ph.getCallInputs(
            address(protocol),
            protocol.transfer.selector
        );
        PhEvm.CallInputs memory transferCall;
        bool found;
        for (uint256 i = 0; i < transferCalls.length; i++) {
            if (transferCalls[i].id == ctx.callStart) {
                transferCall = transferCalls[i];
                found = true;
                break;
            }
        }
        // Without this check an unmatched call leaves `transferCall` zeroed and the decode below
        // fails with an uninformative revert.
        require(found, "Triggering transfer call not found");

        address from = transferCall.caller;
        // NOTE(review): assumes `input` holds only the ABI-encoded (to, amount) arguments, without
        // the 4-byte selector — confirm against PhEvm.CallInputs.
        (address to, uint256 amount) = abi.decode(transferCall.input, (address, uint256));

        // NOTE(review): assumes Protocol stores balances as mapping(address => uint256) at storage
        // slot 0 — confirm against the Protocol storage layout.
        bytes32 fromBalanceSlot = keccak256(abi.encode(from, uint256(0)));
        bytes32 toBalanceSlot = keccak256(abi.encode(to, uint256(0)));

        // forkType 2 is used for the pre-call snapshot and 3 for the post-call snapshot here.
        // NOTE(review): confirm this numeric encoding against PhEvm.ForkId.
        PhEvm.ForkId memory preCall = PhEvm.ForkId({forkType: 2, callIndex: ctx.callStart});
        PhEvm.ForkId memory postCall = PhEvm.ForkId({forkType: 3, callIndex: ctx.callEnd});

        uint256 preFromBalance = uint256(ph.loadStateAt(address(protocol), fromBalanceSlot, preCall));
        uint256 postFromBalance = uint256(ph.loadStateAt(address(protocol), fromBalanceSlot, postCall));

        if (from == to) {
            // Sender and recipient share one slot, so the debit and credit must cancel out.
            // The separate debit/credit checks below would wrongly fail here for any amount > 0.
            require(postFromBalance == preFromBalance, "Self-transfer changed balance in call frame");
            return;
        }

        uint256 preToBalance = uint256(ph.loadStateAt(address(protocol), toBalanceSlot, preCall));
        uint256 postToBalance = uint256(ph.loadStateAt(address(protocol), toBalanceSlot, postCall));

        // Compare in a way that cannot hit checked-arithmetic panics, so a violation always
        // surfaces with the intended message (`&&` short-circuits before the subtraction).
        require(
            preFromBalance >= amount && postFromBalance == preFromBalance - amount,
            "From balance mismatch in call frame"
        );
        require(
            postToBalance >= preToBalance && postToBalance - preToBalance == amount,
            "To balance mismatch in call frame"
        );
    }
}