From 05eb646848065dc0f03ba3bd745a371fec5b648a Mon Sep 17 00:00:00 2001 From: Alex Towle Date: Wed, 21 Aug 2019 13:29:50 -0700 Subject: [PATCH] `@0x:contracts-utils` Added a Refundable contract --- contracts/utils/contracts/src/Refundable.sol | 10 +- .../utils/contracts/test/TestRefundable.sol | 62 ++++--- .../contracts/test/TestRefundableReceiver.sol | 168 ++++++++++++++++++ contracts/utils/package.json | 4 +- contracts/utils/src/artifacts.ts | 2 + contracts/utils/src/wrappers.ts | 1 + contracts/utils/test/refundable.ts | 159 ++++++++++------- contracts/utils/tsconfig.json | 1 + 8 files changed, 313 insertions(+), 94 deletions(-) create mode 100644 contracts/utils/contracts/test/TestRefundableReceiver.sol diff --git a/contracts/utils/contracts/src/Refundable.sol b/contracts/utils/contracts/src/Refundable.sol index a0daf1bfcb..b00b08cfe1 100644 --- a/contracts/utils/contracts/src/Refundable.sol +++ b/contracts/utils/contracts/src/Refundable.sol @@ -24,12 +24,20 @@ contract Refundable { // This bool is used by the refund modifier to allow for lazily evaluated refunds. bool internal shouldNotRefund; - modifier refund { + modifier refundFinalBalance { + _; + if (!shouldNotRefund) { + msg.sender.transfer(address(this).balance); + } + } + + modifier disableRefundUntilEnd { if (shouldNotRefund) { _; } else { shouldNotRefund = true; _; + shouldNotRefund = false; msg.sender.transfer(address(this).balance); } } diff --git a/contracts/utils/contracts/test/TestRefundable.sol b/contracts/utils/contracts/test/TestRefundable.sol index 8d9c16ce59..edb6ade998 100644 --- a/contracts/utils/contracts/test/TestRefundable.sol +++ b/contracts/utils/contracts/test/TestRefundable.sol @@ -24,43 +24,49 @@ import "../src/Refundable.sol"; contract TestRefundable is Refundable { - uint256 public counter = 2; - - function setCounter(uint256 newCounter) + function setShouldNotRefund(bool shouldNotRefundNew) external { - counter = newCounter; + shouldNotRefund = shouldNotRefundNew; } - function complexReentrantRefundFunction() + function getShouldNotRefund() external - payable - refund() + view + returns (bool) { - if (counter == 0) { - // This call tests lazy evaluation across different functions with the refund modifier - this.simpleRefundFunction(); - } else { - counter--; - this.complexReentrantRefundFunction(); - } + return shouldNotRefund; } - function simpleReentrantRefundFunction() - external + function refundFinalBalanceFunction() + public payable - refund() - { - if (counter != 0) { - counter--; - this.simpleReentrantRefundFunction(); - } - } - - function simpleRefundFunction() - external - payable - refund() + refundFinalBalance {} // solhint-disable-line no-empty-blocks + function disableRefundUntilEndFunction() + public + payable + disableRefundUntilEnd + {} // solhint-disable-line no-empty-blocks + + function nestedDisableRefundUntilEndFunction() + public + payable + disableRefundUntilEnd + returns (uint256) + { + disableRefundUntilEndFunction(); + return address(this).balance; + } + + function mixedRefundModifierFunction() + public + payable + disableRefundUntilEnd + returns (uint256) + { + refundFinalBalanceFunction(); + return address(this).balance; + } } diff --git a/contracts/utils/contracts/test/TestRefundableReceiver.sol b/contracts/utils/contracts/test/TestRefundableReceiver.sol new file mode 100644 index 0000000000..621beaf5c4 --- /dev/null +++ b/contracts/utils/contracts/test/TestRefundableReceiver.sol @@ -0,0 +1,168 @@ +/* + + Copyright 2019 ZeroEx Intl. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + +pragma solidity ^0.5.9; + +import "./TestRefundable.sol"; + + +contract TestRefundableReceiver { + + /// @dev A payable fallback function is necessary to receive refunds from the `TestRefundable` contract. + function () + external + payable + {} // solhint-disable-line no-empty-blocks + + /// @dev This function tests the behavior to a simple call to `refundFinalBalanceFunction`. This + /// test will verify that the correct refund was provided after the call (depending on whether + /// a refund should be provided), and it will ensure that the `shouldNotRefund` state variable + /// remains unaltered after the function call. + /// @param testRefundable The TestRefundable that should be tested against. + /// @param shouldNotRefund The value that shouldNotRefund should be set to before the call to TestRefundable. + function testRefundFinalBalance( + TestRefundable testRefundable, + bool shouldNotRefund + ) + external + payable + { + // Set `shouldNotRefund` to the specified bool. + testRefundable.setShouldNotRefund(shouldNotRefund); + + // Call `refundFinalBalanceFunction` and forward all value from the contract. + testRefundable.refundFinalBalanceFunction.value(msg.value)(); + + // Assert that the expected refunds happened and that the `shouldNotRefund` value was + // set back to an unaltered state after the call. + requireCorrectFinalBalancesAndState(testRefundable, shouldNotRefund); + } + + /// @dev This function tests the behavior to a simple call to `disableRefundUntilEndFunction`. This + /// test will verify that the correct refund was provided after the call (depending on whether + /// a refund should be provided), and it will ensure that the `shouldNotRefund` state variable + /// remains unaltered after the function call. + /// @param testRefundable The TestRefundable that should be tested against. + /// @param shouldNotRefund The value that shouldNotRefund should be set to before the call to TestRefundable. + function testDisableRefundUntilEnd( + TestRefundable testRefundable, + bool shouldNotRefund + ) + external + payable + { + // Set `shouldNotRefund` to the specified bool. + testRefundable.setShouldNotRefund(shouldNotRefund); + + // Call `disableRefundUntilEndFunction` and forward all value from the contract. + testRefundable.disableRefundUntilEndFunction.value(msg.value)(); + + // Assert that the expected refunds happened and that the `shouldNotRefund` value was + // set back to an unaltered state after the call. + requireCorrectFinalBalancesAndState(testRefundable, shouldNotRefund); + } + + /// @dev This function tests the behavior of a call to a function that has the `disableRefundUntilEndFunction`. + /// The function that is called also uses the `disableRefundUntilEndFunction`, so this function's role is + /// to verify that both the inner and outer modifiers worked correctly. + /// @param testRefundable The TestRefundable that should be tested against. + /// @param shouldNotRefund The value that shouldNotRefund should be set to before the call to TestRefundable. + function testNestedDisableRefundUntilEnd( + TestRefundable testRefundable, + bool shouldNotRefund + ) + external + payable + { + // Set `shouldNotRefund` to the specified bool. + testRefundable.setShouldNotRefund(shouldNotRefund); + + // Call `nestedDisableRefundUntilEndFunction` and forward all value from the contract. + uint256 balanceWithinCall = testRefundable.nestedDisableRefundUntilEndFunction.value(msg.value)(); + + // Ensure that the balance within the call was equal to `msg.value` since the inner refund should + // not have been triggered regardless of the value of `shouldNotRefund`. + require(balanceWithinCall == msg.value, "Incorrect inner balance"); + + // Assert that the expected refunds happened and that the `shouldNotRefund` value was + // set back to an unaltered state after the call. + requireCorrectFinalBalancesAndState(testRefundable, shouldNotRefund); + } + + /// @dev This function tests the behavior of a call to a function that has the `disableRefundUntilEndFunction`. + /// The function that is called uses the `refundFinalBalanceFunction`, so this function's role is + /// to verify that both the inner and outer modifiers worked correctly. + /// @param testRefundable The TestRefundable that should be tested against. + /// @param shouldNotRefund The value that shouldNotRefund should be set to before the call to TestRefundable. + function testMixedRefunds( + TestRefundable testRefundable, + bool shouldNotRefund + ) + external + payable + { + // Set `shouldNotRefund` to the specified bool. + testRefundable.setShouldNotRefund(shouldNotRefund); + + // Call `mixedRefundModifierFunction` and forward all value from the contract. + uint256 balanceWithinCall = testRefundable.mixedRefundModifierFunction.value(msg.value)(); + + // Ensure that the balance within the call was equal to `msg.value` since the inner refund should + // not have been triggered regardless of the value of `shouldNotRefund`. + require(balanceWithinCall == msg.value, "Incorrect inner balance"); + + // Assert that the expected refunds happened and that the `shouldNotRefund` value was + // set back to an unaltered state after the call. + requireCorrectFinalBalancesAndState(testRefundable, shouldNotRefund); + } + + /// @dev This helper function verifies the final balances of this receiver contract and a specified + /// refundable contract and verifies that the `shouldNotRefund` value remains unaltered. + /// @param testRefundable The TestRefundable that should be tested against. + /// @param shouldNotRefund The value that shouldNotRefund was set to before the call to TestRefundable. + function requireCorrectFinalBalancesAndState( + TestRefundable testRefundable, + bool shouldNotRefund + ) + internal + { + // If `shouldNotRefund` was true, then this contract should have a balance of zero, + // and `testRefundable` should have a balance of `msg.value`. Otherwise, the opposite + // should be true. + if (shouldNotRefund) { + // Ensure that this contract's balance is zero. + require(address(this).balance == 0, "Incorrect balance for TestRefundableReceiver"); + + // Ensure that the other contract's balance is equal to `msg.value`. + require(address(testRefundable).balance == msg.value, "Incorrect balance for TestRefundable"); + } else { + // Ensure that this contract's balance is `msg.value`. + require(address(this).balance == msg.value, "Incorrect balance for TestRefundableReceiver"); + + // Ensure that the other contract's balance is equal to zero. + require(address(testRefundable).balance == 0, "Incorrect balance for TestRefundable"); + } + + // Ensure that `shouldNotRefund` in TestRefundable is set to the parameter `shouldNotRefund` + // after the call (i.e. the value didn't change during the function call). + require(testRefundable.getShouldNotRefund() == shouldNotRefund, "Incorrect shouldNotRefund value"); + + // Drain the contract of funds so that subsequent tests don't have to account for leftover ether. + msg.sender.transfer(address(this).balance); + } +} diff --git a/contracts/utils/package.json b/contracts/utils/package.json index 0eae37c102..80096fe2c8 100644 --- a/contracts/utils/package.json +++ b/contracts/utils/package.json @@ -35,8 +35,8 @@ "compile:truffle": "truffle compile" }, "config": { - "abis": "./generated-artifacts/@(Authorizable|IAuthorizable|IOwnable|LibAddress|LibAddressArray|LibAddressArrayRichErrors|LibAuthorizableRichErrors|LibBytes|LibBytesRichErrors|LibEIP1271|LibEIP712|LibOwnableRichErrors|LibReentrancyGuardRichErrors|LibRichErrors|LibSafeMath|LibSafeMathRichErrors|Ownable|ReentrancyGuard|Refundable|SafeMath|TestLibAddress|TestLibAddressArray|TestLibBytes|TestLibEIP712|TestLibRichErrors|TestLogDecoding|TestLogDecodingDownstream|TestOwnable|TestReentrancyGuard|TestRefundable|TestSafeMath).json", - "abis:comment": "This list is auto-generated by contracts-gen. Don't edit manually." + "abis:comment": "This list is auto-generated by contracts-gen. Don't edit manually.", + "abis": "./generated-artifacts/@(Authorizable|IAuthorizable|IOwnable|LibAddress|LibAddressArray|LibAddressArrayRichErrors|LibAuthorizableRichErrors|LibBytes|LibBytesRichErrors|LibEIP1271|LibEIP712|LibOwnableRichErrors|LibReentrancyGuardRichErrors|LibRichErrors|LibSafeMath|LibSafeMathRichErrors|Ownable|ReentrancyGuard|Refundable|SafeMath|TestLibAddress|TestLibAddressArray|TestLibBytes|TestLibEIP712|TestLibRichErrors|TestLogDecoding|TestLogDecodingDownstream|TestOwnable|TestReentrancyGuard|TestRefundable|TestRefundableReceiver|TestSafeMath).json" }, "repository": { "type": "git", diff --git a/contracts/utils/src/artifacts.ts b/contracts/utils/src/artifacts.ts index 7292b47a0c..9d0a7c8831 100644 --- a/contracts/utils/src/artifacts.ts +++ b/contracts/utils/src/artifacts.ts @@ -35,6 +35,7 @@ import * as TestLogDecodingDownstream from '../generated-artifacts/TestLogDecodi import * as TestOwnable from '../generated-artifacts/TestOwnable.json'; import * as TestReentrancyGuard from '../generated-artifacts/TestReentrancyGuard.json'; import * as TestRefundable from '../generated-artifacts/TestRefundable.json'; +import * as TestRefundableReceiver from '../generated-artifacts/TestRefundableReceiver.json'; import * as TestSafeMath from '../generated-artifacts/TestSafeMath.json'; export const artifacts = { Authorizable: Authorizable as ContractArtifact, @@ -67,5 +68,6 @@ export const artifacts = { TestOwnable: TestOwnable as ContractArtifact, TestReentrancyGuard: TestReentrancyGuard as ContractArtifact, TestRefundable: TestRefundable as ContractArtifact, + TestRefundableReceiver: TestRefundableReceiver as ContractArtifact, TestSafeMath: TestSafeMath as ContractArtifact, }; diff --git a/contracts/utils/src/wrappers.ts b/contracts/utils/src/wrappers.ts index 02ec0cfd46..a14f4d30f0 100644 --- a/contracts/utils/src/wrappers.ts +++ b/contracts/utils/src/wrappers.ts @@ -33,4 +33,5 @@ export * from '../generated-wrappers/test_log_decoding_downstream'; export * from '../generated-wrappers/test_ownable'; export * from '../generated-wrappers/test_reentrancy_guard'; export * from '../generated-wrappers/test_refundable'; +export * from '../generated-wrappers/test_refundable_receiver'; export * from '../generated-wrappers/test_safe_math'; diff --git a/contracts/utils/test/refundable.ts b/contracts/utils/test/refundable.ts index dc0180c266..402353977d 100644 --- a/contracts/utils/test/refundable.ts +++ b/contracts/utils/test/refundable.ts @@ -1,85 +1,118 @@ -import { chaiSetup, constants, provider, txDefaults, web3Wrapper } from '@0x/contracts-test-utils'; -import { BlockchainLifecycle } from '@0x/dev-utils'; -import { Web3Wrapper } from '@0x/web3-wrapper'; -import * as chai from 'chai'; +import { blockchainTests } from '@0x/contracts-test-utils'; +import { BigNumber } from '@0x/utils'; import * as _ from 'lodash'; -import { artifacts, TestRefundableContract } from '../src'; +import { artifacts, TestRefundableContract, TestRefundableReceiverContract } from '../src'; -chaiSetup.configure(); -const expect = chai.expect; -const blockchainLifecycle = new BlockchainLifecycle(web3Wrapper); - -describe('Refundable', () => { - let owner: string; - let notOwner: string; - let address: string; +blockchainTests('Refundable', env => { let refundable: TestRefundableContract; + let receiver: TestRefundableReceiverContract; before(async () => { - await blockchainLifecycle.startAsync(); - }); - - after(async () => { - await blockchainLifecycle.revertAsync(); - }); - - before(async () => { - const accounts = await web3Wrapper.getAvailableAddressesAsync(); - [owner, address, notOwner] = _.slice(accounts, 0, 3); + // Create the refundable contract. refundable = await TestRefundableContract.deployFrom0xArtifactAsync( artifacts.TestRefundable, - provider, - txDefaults, + env.provider, + env.txDefaults, + {}, + ); + + // Create the receiver contract. + receiver = await TestRefundableReceiverContract.deployFrom0xArtifactAsync( + artifacts.TestRefundableReceiver, + env.provider, + env.txDefaults, {}, ); }); - beforeEach(async () => { - await blockchainLifecycle.startAsync(); + // The contents of these typescript tests is not adequate to understand the assertions that are made during + // these calls. For a more accurate picture, checkout out "./contracts/test/TestRefundableReceiver.sol". + blockchainTests.resets('refundFinalBalance', async () => { + it('should fully refund the sender when `shouldNotRefund` is false', async () => { + // Send 100 wei to the refundable contract that should be refunded to the receiver contract. + await receiver.testRefundFinalBalance.awaitTransactionSuccessAsync(refundable.address, false, { + value: new BigNumber(100), + }); + }); + + // This test may not be necessary, but it is included here as a sanity check. + it('should fully refund the sender when `shouldNotRefund` is false for two calls in a row', async () => { + // Send 100 wei to the refundable contract that should be refunded to the receiver contract. + await receiver.testRefundFinalBalance.awaitTransactionSuccessAsync(refundable.address, false, { + value: new BigNumber(100), + }); + + // Send 1000 wei to the refundable contract that should be refunded to the receiver contract. + await receiver.testRefundFinalBalance.awaitTransactionSuccessAsync(refundable.address, false, { + value: new BigNumber(1000), + }); + }); + + it('should not refund the sender if `shouldNotRefund` is true', async () => { + /// Send 100 wei to the refundable contract that should not be refunded. + await receiver.testRefundFinalBalance.awaitTransactionSuccessAsync(refundable.address, true, { + value: new BigNumber(1000), + }); + }); }); - afterEach(async () => { - await blockchainLifecycle.revertAsync(); - }); - - describe('refund', async () => { - it('should refund all of the ether sent to the simpleRefundFunction', async () => { - await expect( - refundable.simpleRefundFunction.sendTransactionAsync({ - from: owner, - value: Web3Wrapper.toBaseUnitAmount(1, 18), - }), - ).to.be.fulfilled(''); // tslint:disable-line:await-promise - expect(await web3Wrapper.getBalanceInWeiAsync(refundable.address)).bignumber.to.be.eq( - constants.ZERO_AMOUNT, - ); + // The contents of these typescript tests is not adequate to understand the assertions that are made during + // these calls. For a more accurate picture, checkout out "./contracts/test/TestRefundableReceiver.sol". + blockchainTests.resets('disableRefundUntilEnd', async () => { + it('should fully refund the sender when `shouldNotRefund` is false', async () => { + // Send 100 wei to the refundable contract that should be refunded to the receiver contract. + await receiver.testDisableRefundUntilEnd.awaitTransactionSuccessAsync(refundable.address, false, { + value: new BigNumber(100), + }); }); - it('should refund all of the ether sent to the simpleReentrantRefundFunction with a counter of 2', async () => { - await expect( - refundable.simpleReentrantRefundFunction.sendTransactionAsync({ - from: owner, - value: Web3Wrapper.toBaseUnitAmount(1, 18), - }), - ).to.be.fulfilled(''); // tslint:disable-line:await-promise - expect(await web3Wrapper.getBalanceInWeiAsync(refundable.address)).bignumber.to.be.eq( - constants.ZERO_AMOUNT, - ); + // This test may not be necessary, but it is included here as a sanity check. + it('should fully refund the sender when `shouldNotRefund` is false for two calls in a row', async () => { + // Send 100 wei to the refundable contract that should be refunded to the receiver contract. + await receiver.testDisableRefundUntilEnd.awaitTransactionSuccessAsync(refundable.address, false, { + value: new BigNumber(100), + }); + + // Send 1000 wei to the refundable contract that should be refunded to the receiver contract. + await receiver.testDisableRefundUntilEnd.awaitTransactionSuccessAsync(refundable.address, false, { + value: new BigNumber(1000), + }); }); - it('should refund all of the ether sent to the complexReentrantRefundFunction with a counter of 2', async () => { - await expect( - refundable.complexReentrantRefundFunction.sendTransactionAsync({ - from: owner, - value: Web3Wrapper.toBaseUnitAmount(1, 18), - }), - ).to.be.fulfilled(''); // tslint:disable-line:await-promise - expect(await web3Wrapper.getBalanceInWeiAsync(refundable.address)).bignumber.to.be.eq( - constants.ZERO_AMOUNT, - ); + it('should not refund the sender if `shouldNotRefund` is true', async () => { + /// Send 100 wei to the refundable contract that should not be refunded. + await receiver.testDisableRefundUntilEnd.awaitTransactionSuccessAsync(refundable.address, false, { + value: new BigNumber(100), + }); }); - // FIXME - Receiver tests + it('should disable the `disableRefundUntilEnd` modifier and refund when `shouldNotRefund` is false', async () => { + /// Send 100 wei to the refundable contract that should be refunded. + await receiver.testNestedDisableRefundUntilEnd.awaitTransactionSuccessAsync(refundable.address, false, { + value: new BigNumber(100), + }); + }); + + it('should disable the `refundFinalBalance` modifier and send no refund when `shouldNotRefund` is true', async () => { + /// Send 100 wei to the refundable contract that should not be refunded. + await receiver.testNestedDisableRefundUntilEnd.awaitTransactionSuccessAsync(refundable.address, true, { + value: new BigNumber(100), + }); + }); + + it('should disable the `refundFinalBalance` modifier and refund when `shouldNotRefund` is false', async () => { + /// Send 100 wei to the refundable contract that should be refunded. + await receiver.testMixedRefunds.awaitTransactionSuccessAsync(refundable.address, false, { + value: new BigNumber(100), + }); + }); + + it('should disable the `refundFinalBalance` modifier and send no refund when `shouldNotRefund` is true', async () => { + /// Send 100 wei to the refundable contract that should not be refunded. + await receiver.testMixedRefunds.awaitTransactionSuccessAsync(refundable.address, true, { + value: new BigNumber(100), + }); + }); }); }); diff --git a/contracts/utils/tsconfig.json b/contracts/utils/tsconfig.json index 152bf60333..deb1b76698 100644 --- a/contracts/utils/tsconfig.json +++ b/contracts/utils/tsconfig.json @@ -33,6 +33,7 @@ "generated-artifacts/TestOwnable.json", "generated-artifacts/TestReentrancyGuard.json", "generated-artifacts/TestRefundable.json", + "generated-artifacts/TestRefundableReceiver.json", "generated-artifacts/TestSafeMath.json" ], "exclude": ["./deploy/solc/solc_bin"]