Add basic implementation of token bucket for flow-control (#3106)

This commit is contained in:
LearningGp 2023-04-25 22:08:36 +08:00 committed by LearningGp
parent 74f7d184cf
commit 0c97d8f4b1
5 changed files with 368 additions and 0 deletions

View File

@ -0,0 +1,117 @@
/*
* Copyright 1999-2023 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.csp.sentinel.slots.block.flow.tokenbucket;
import com.alibaba.csp.sentinel.util.AssertUtil;
import com.alibaba.csp.sentinel.util.TimeUtil;
/**
* @author LearningGp
*/
public class AbstractTokenBucket implements TokenBucket{
protected final long MAX_UNIT_PRODUCE_NUM = Long.MAX_VALUE;
/**
* Number of tokens left in the bucket
*/
protected volatile long currentTokenNum;
/**
* Time of next production token
*/
protected volatile long nextProduceTime;
/**
* Number of tokens produced per unit of time
*/
protected final long unitProduceNum;
/**
* Maximum number of tokens stored in the bucket
*/
protected final long maxTokenNum;
protected final long intervalInMs;
protected final long startTime;
public AbstractTokenBucket(long unitProduceNum, long maxTokenNum, boolean fullStart, long intervalInMs) {
AssertUtil.isTrue(unitProduceNum > 0 && intervalInMs > 0 && unitProduceNum < MAX_UNIT_PRODUCE_NUM,
"Illegal unitProduceNum or intervalInSeconds");
AssertUtil.isTrue(maxTokenNum > 0, "Illegal maxTokenNum");
this.unitProduceNum = unitProduceNum;
this.maxTokenNum = maxTokenNum;
this.intervalInMs = intervalInMs;
this.startTime = TimeUtil.currentTimeMillis();
this.nextProduceTime = startTime;
if (fullStart) {
this.currentTokenNum = maxTokenNum;
} else {
//The token will be filled when the first request arrives (including the initial token)
this.currentTokenNum = 0;
}
}
@Override
public boolean tryConsume(long tokenNum) {
if (tokenNum <= 0) {
return true;
}
if (tokenNum > maxTokenNum) {
return false;
}
long currentTimestamp = TimeUtil.currentTimeMillis();
refreshCurrentTokenNum(currentTimestamp);
if (tokenNum <= currentTokenNum) {
currentTokenNum -= tokenNum;
return true;
} else {
return false;
}
}
@Override
public void refreshCurrentTokenNum(long currentTimestamp) {
if (nextProduceTime > currentTimestamp) {
return;
}
currentTokenNum = Math.min(maxTokenNum, currentTokenNum + calProducedTokenNum(currentTimestamp));
updateNextProduceTime(currentTimestamp);
}
protected long calProducedTokenNum(long currentTimestamp) {
if (nextProduceTime > currentTimestamp) {
return 0;
}
long nextRefreshUnitCount = (nextProduceTime - startTime) / intervalInMs;
long currentUnitCount = (currentTimestamp - startTime) / intervalInMs;
long unitCount = currentUnitCount - nextRefreshUnitCount + 1;
return unitCount * unitProduceNum;
}
protected void updateNextProduceTime(long currentTimestamp) {
nextProduceTime = intervalInMs - ((currentTimestamp - startTime) % intervalInMs) + currentTimestamp;
}
public long refreshTokenAndGetCurrentTokenNum() {
refreshCurrentTokenNum(TimeUtil.currentTimeMillis());
return currentTokenNum;
}
public long getCurrentTokenNum() {
return currentTokenNum;
}
}

View File

@ -0,0 +1,30 @@
/*
* Copyright 1999-2023 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.csp.sentinel.slots.block.flow.tokenbucket;
/**
* @author LearningGp
*/
public class DefaultTokenBucket extends AbstractTokenBucket{
public DefaultTokenBucket(long unitProduceNum, long maxTokenNum, long intervalInMs){
super(unitProduceNum, maxTokenNum, false, intervalInMs);
}
public DefaultTokenBucket(long unitProduceNum, long maxTokenNum, boolean fullStart, long intervalInMs){
super(unitProduceNum, maxTokenNum, fullStart, intervalInMs);
}
}

View File

@ -0,0 +1,69 @@
/*
* Copyright 1999-2023 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.csp.sentinel.slots.block.flow.tokenbucket;
import com.alibaba.csp.sentinel.util.TimeUtil;
/**
* @author LearningGp
*/
public class StrictTokenBucket extends AbstractTokenBucket{
final private Object refreshLock = new Object();
final private Object consumeLock = new Object();
public StrictTokenBucket(long unitProduceNum, long maxTokenNum, long intervalInMs) {
super(unitProduceNum, maxTokenNum, false, intervalInMs);
}
public StrictTokenBucket(long unitProduceNum, long maxTokenNum, boolean fullStart, long intervalInMs) {
super(unitProduceNum, maxTokenNum, fullStart, intervalInMs);
}
@Override
public boolean tryConsume(long tokenNum) {
if (tokenNum > maxTokenNum) {
return false;
}
long currentTimestamp = TimeUtil.currentTimeMillis();
refreshCurrentTokenNum(currentTimestamp);
if (tokenNum <= currentTokenNum) {
synchronized (consumeLock) {
if (tokenNum <= currentTokenNum) {
currentTokenNum -= tokenNum;
return true;
}
}
}
return false;
}
@Override
public void refreshCurrentTokenNum(long currentTimestamp) {
if (nextProduceTime > currentTimestamp) {
return;
}
long producedTokenNum = calProducedTokenNum(currentTimestamp);
synchronized (refreshLock) {
if (nextProduceTime > currentTimestamp) {
return;
}
currentTokenNum = Math.min(maxTokenNum, currentTokenNum + producedTokenNum);
updateNextProduceTime(currentTimestamp);
}
}
}

View File

@ -0,0 +1,27 @@
/*
* Copyright 1999-2023 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.csp.sentinel.slots.block.flow.tokenbucket;
/**
* @author LearningGp
*/
public interface TokenBucket {
boolean tryConsume(long tokenNum);
void refreshCurrentTokenNum(long timestamp);
}

View File

@ -0,0 +1,125 @@
/*
* Copyright 1999-2023 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.csp.sentinel.slots.block.flow.tokenbucket;
import com.alibaba.csp.sentinel.concurrent.NamedThreadFactory;
import com.alibaba.csp.sentinel.test.AbstractTimeBasedTest;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
/**
* @author LearningGp
*/
public class TokenBucketTest extends AbstractTimeBasedTest {
private static ThreadPoolExecutor threadPoolExecutor;
@BeforeClass
public static void beforeClass() throws Exception {
threadPoolExecutor = new ThreadPoolExecutor(64, 64, 0,
TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(),
new NamedThreadFactory("sentinel-token-bucket-test", true),
new ThreadPoolExecutor.AbortPolicy());
}
@AfterClass
public static void afterClass() throws Exception {
threadPoolExecutor.shutdownNow();
}
@Test
public void testForDefaultTokenBucket() throws InterruptedException {
long unitProduceNum = 1;
long maxTokenNum = 2;
long intervalInMs = 1000;
long testStart = System.currentTimeMillis();
setCurrentMillis(testStart);
DefaultTokenBucket defaultTokenBucket = new DefaultTokenBucket(unitProduceNum, maxTokenNum, intervalInMs);
assertTrue(defaultTokenBucket.tryConsume(1));
assertFalse(defaultTokenBucket.tryConsume(1));
DefaultTokenBucket defaultTokenBucketFullStart = new DefaultTokenBucket(unitProduceNum, maxTokenNum,
true, intervalInMs);
assertTrue(defaultTokenBucketFullStart.tryConsume(2));
assertFalse(defaultTokenBucketFullStart.tryConsume(1));
sleep(1000);
assertTrue(defaultTokenBucket.tryConsume(1));
assertFalse(defaultTokenBucket.tryConsume(1));
sleep(1000);
assertTrue(defaultTokenBucketFullStart.tryConsume(2));
assertFalse(defaultTokenBucketFullStart.tryConsume(1));
}
@Test
public void testForStrictTokenBucket() throws InterruptedException {
long unitProduceNum = 5;
long maxTokenNum = 10;
long intervalInMs = 1000;
final int n = 64;
long testStart = System.currentTimeMillis();
setCurrentMillis(testStart);
final AtomicLong passNum = new AtomicLong();
final AtomicLong passNumFullStart = new AtomicLong();
final CountDownLatch countDownLatch = new CountDownLatch(n);
final CountDownLatch countDownLatchFullStart = new CountDownLatch(n);
final StrictTokenBucket strictTokenBucket = new StrictTokenBucket(unitProduceNum, maxTokenNum, intervalInMs);
final StrictTokenBucket strictTokenBucketFullStart = new StrictTokenBucket(unitProduceNum, maxTokenNum, true,
intervalInMs);
for (int i = 0; i < n; i++) {
threadPoolExecutor.execute(new Runnable() {
@Override
public void run() {
if (strictTokenBucket.tryConsume(1)) {
passNum.incrementAndGet();
}
countDownLatch.countDown();
}
});
threadPoolExecutor.execute(new Runnable() {
@Override
public void run() {
if (strictTokenBucketFullStart.tryConsume(1)) {
passNumFullStart.incrementAndGet();
}
countDownLatchFullStart.countDown();
}
});
}
countDownLatch.await();
countDownLatchFullStart.await();
assertEquals(5, passNum.longValue());
assertEquals(10, passNumFullStart.longValue());
}
}