首页 > 编程语言 > java简单手写版本实现时间轮算法
2021
04-07

java简单手写版本实现时间轮算法

时间轮

关于时间轮的介绍,网上有很多,这里就不重复了

核心思想

  • 一个环形数组存储时间轮的所有槽(看你的手表),每个槽对应当前时间轮的最小精度
  • 超过当前时间轮最大表示范围的会被丢到上层时间轮,上层时间轮的最小精度即为下层时间轮能表达的最大时间(时分秒概念)
  • 每个槽对应一个环形链表存储该时间应该被执行的任务
  • 需要一个线程去驱动指针运转,获取到期任务

以下给出java 简单手写版本实现

代码实现

时间轮主数据结构

/**
 * @author apdoer
 * @version 1.0
 * @date 2021/3/22 19:31
 */
@Slf4j
public class TimeWheel {
  /**
   * 一个槽的时间间隔(时间轮最小刻度)
   */
  private long tickMs;

  /**
   * 时间轮大小(槽的个数)
   */
  private int wheelSize;

  /**
   * 一轮的时间跨度
   */
  private long interval;

  private long currentTime;

  /**
   * 槽
   */
  private TimerTaskList[] buckets;

  /**
   * 上层时间轮
   */
  private volatile TimeWheel overflowWheel;

  /**
   * 一个timer只有一个delayqueue
   */
  private DelayQueue<TimerTaskList> delayQueue;

  public TimeWheel(long tickMs, int wheelSize, long currentTime, DelayQueue<TimerTaskList> delayQueue) {
    this.currentTime = currentTime;
    this.tickMs = tickMs;
    this.wheelSize = wheelSize;
    this.interval = tickMs * wheelSize;
    this.buckets = new TimerTaskList[wheelSize];
    this.currentTime = currentTime - (currentTime % tickMs);
    this.delayQueue = delayQueue;
    for (int i = 0; i < wheelSize; i++) {
      buckets[i] = new TimerTaskList();
    }
  }

  public boolean add(TimerTaskEntry entry) {
    long expiration = entry.getExpireMs();
    if (expiration < tickMs + currentTime) {
      //到期了
      return false;
    } else if (expiration < currentTime + interval) {
      //扔进当前时间轮的某个槽里,只有时间大于某个槽,才会放进去
      long virtualId = (expiration / tickMs);
      int index = (int) (virtualId % wheelSize);
      TimerTaskList bucket = buckets[index];
      bucket.addTask(entry);
      //设置bucket 过期时间
      if (bucket.setExpiration(virtualId * tickMs)) {
        //设好过期时间的bucket需要入队
        delayQueue.offer(bucket);
        return true;
      }
    } else {
      //当前轮不能满足,需要扔到上一轮
      TimeWheel timeWheel = getOverflowWheel();
      return timeWheel.add(entry);
    }
    return false;
  }


  private TimeWheel getOverflowWheel() {
    if (overflowWheel == null) {
      synchronized (this) {
        if (overflowWheel == null) {
          overflowWheel = new TimeWheel(interval, wheelSize, currentTime, delayQueue);
        }
      }
    }
    return overflowWheel;
  }

  /**
   * 推进指针
   *
   * @param timestamp
   */
  public void advanceLock(long timestamp) {
    if (timestamp > currentTime + tickMs) {
      currentTime = timestamp - (timestamp % tickMs);
      if (overflowWheel != null) {
        this.getOverflowWheel().advanceLock(timestamp);
      }
    }
  }
}

定时器接口

/**
 * 定时器
 * @author apdoer
 * @version 1.0
 * @date 2021/3/22 20:30
 */
public interface Timer {

  /**
   * 添加一个新任务
   *
   * @param timerTask
   */
  void add(TimerTask timerTask);


  /**
   * 推动指针
   *
   * @param timeout
   */
  void advanceClock(long timeout);

  /**
   * 等待执行的任务
   *
   * @return
   */
  int size();

  /**
   * 关闭服务,剩下的无法被执行
   */
  void shutdown();
}

定时器实现

/**
 * @author apdoer
 * @version 1.0
 * @date 2021/3/22 20:33
 */
@Slf4j
public class SystemTimer implements Timer {
  /**
   * 底层时间轮
   */
  private TimeWheel timeWheel;
  /**
   * 一个Timer只有一个延时队列
   */
  private DelayQueue<TimerTaskList> delayQueue = new DelayQueue<>();
  /**
   * 过期任务执行线程
   */
  private ExecutorService workerThreadPool;
  /**
   * 轮询delayQueue获取过期任务线程
   */
  private ExecutorService bossThreadPool;


  public SystemTimer() {
    this.timeWheel = new TimeWheel(1, 20, System.currentTimeMillis(), delayQueue);
    this.workerThreadPool = Executors.newFixedThreadPool(100);
    this.bossThreadPool = Executors.newFixedThreadPool(1);
    //20ms推动一次时间轮运转
    this.bossThreadPool.submit(() -> {
      for (; ; ) {
        this.advanceClock(20);
      }
    });
  }


  public void addTimerTaskEntry(TimerTaskEntry entry) {
    if (!timeWheel.add(entry)) {
      //已经过期了
      TimerTask timerTask = entry.getTimerTask();
      log.info("=====任务:{} 已到期,准备执行============",timerTask.getDesc());
      workerThreadPool.submit(timerTask);
    }
  }

  @Override
  public void add(TimerTask timerTask) {
    log.info("=======添加任务开始====task:{}", timerTask.getDesc());
    TimerTaskEntry entry = new TimerTaskEntry(timerTask, timerTask.getDelayMs() + System.currentTimeMillis());
    timerTask.setTimerTaskEntry(entry);
    addTimerTaskEntry(entry);
  }

  /**
   * 推动指针运转获取过期任务
   *
   * @param timeout 时间间隔
   * @return
   */
  @Override
  public synchronized void advanceClock(long timeout) {
    try {
      TimerTaskList bucket = delayQueue.poll(timeout, TimeUnit.MILLISECONDS);
      if (bucket != null) {
        //推进时间
        timeWheel.advanceLock(bucket.getExpiration());
        //执行过期任务(包含降级)
        bucket.clear(this::addTimerTaskEntry);
      }
    } catch (InterruptedException e) {
      log.error("advanceClock error");
    }
  }

  @Override
  public int size() {
    //todo
    return 0;
  }

  @Override
  public void shutdown() {
    this.bossThreadPool.shutdown();
    this.workerThreadPool.shutdown();
    this.timeWheel = null;
  }
}

存储任务的环形链表

/**
 * @author apdoer
 * @version 1.0
 * @date 2021/3/22 19:26
 */
@Data
@Slf4j
class TimerTaskList implements Delayed {
  /**
   * TimerTaskList 环形链表使用一个虚拟根节点root
   */
  private TimerTaskEntry root = new TimerTaskEntry(null, -1);

  {
    root.next = root;
    root.prev = root;
  }

  /**
   * bucket的过期时间
   */
  private AtomicLong expiration = new AtomicLong(-1L);

  public long getExpiration() {
    return expiration.get();
  }

  /**
   * 设置bucket的过期时间,设置成功返回true
   *
   * @param expirationMs
   * @return
   */
  boolean setExpiration(long expirationMs) {
    return expiration.getAndSet(expirationMs) != expirationMs;
  }

  public boolean addTask(TimerTaskEntry entry) {
    boolean done = false;
    while (!done) {
      //如果TimerTaskEntry已经在别的list中就先移除,同步代码块外面移除,避免死锁,一直到成功为止
      entry.remove();
      synchronized (this) {
        if (entry.timedTaskList == null) {
          //加到链表的末尾
          entry.timedTaskList = this;
          TimerTaskEntry tail = root.prev;
          entry.prev = tail;
          entry.next = root;
          tail.next = entry;
          root.prev = entry;
          done = true;
        }
      }
    }
    return true;
  }

  /**
   * 从 TimedTaskList 移除指定的 timerTaskEntry
   *
   * @param entry
   */
  public void remove(TimerTaskEntry entry) {
    synchronized (this) {
      if (entry.getTimedTaskList().equals(this)) {
        entry.next.prev = entry.prev;
        entry.prev.next = entry.next;
        entry.next = null;
        entry.prev = null;
        entry.timedTaskList = null;
      }
    }
  }

  /**
   * 移除所有
   */
  public synchronized void clear(Consumer<TimerTaskEntry> entry) {
    TimerTaskEntry head = root.next;
    while (!head.equals(root)) {
      remove(head);
      entry.accept(head);
      head = root.next;
    }
    expiration.set(-1L);
  }

  @Override
  public long getDelay(TimeUnit unit) {
    return Math.max(0, unit.convert(expiration.get() - System.currentTimeMillis(), TimeUnit.MILLISECONDS));
  }

  @Override
  public int compareTo(Delayed o) {
    if (o instanceof TimerTaskList) {
      return Long.compare(expiration.get(), ((TimerTaskList) o).expiration.get());
    }
    return 0;
  }
}

存储任务的容器entry

/**
 * @author apdoer
 * @version 1.0
 * @date 2021/3/22 19:26
 */
@Data
class TimerTaskEntry implements Comparable<TimerTaskEntry> {
  private TimerTask timerTask;
  private long expireMs;
  volatile TimerTaskList timedTaskList;
  TimerTaskEntry next;
  TimerTaskEntry prev;

  public TimerTaskEntry(TimerTask timedTask, long expireMs) {
    this.timerTask = timedTask;
    this.expireMs = expireMs;
    this.next = null;
    this.prev = null;
  }

  void remove() {
    TimerTaskList currentList = timedTaskList;
    while (currentList != null) {
      currentList.remove(this);
      currentList = timedTaskList;
    }
  }

  @Override
  public int compareTo(TimerTaskEntry o) {
    return ((int) (this.expireMs - o.expireMs));
  }
}

任务包装类(这里也可以将工作任务以线程变量的方式去传入)

@Data
@Slf4j
class TimerTask implements Runnable {
  /**
   * 延时时间
   */
  private long delayMs;
  /**
   * 任务所在的entry
   */
  private TimerTaskEntry timerTaskEntry;

  private String desc;

  public TimerTask(String desc, long delayMs) {
    this.desc = desc;
    this.delayMs = delayMs;
    this.timerTaskEntry = null;
  }

  public synchronized void setTimerTaskEntry(TimerTaskEntry entry) {
    // 如果这个timetask已经被一个已存在的TimerTaskEntry持有,先移除一个
    if (timerTaskEntry != null && timerTaskEntry != entry) {
      timerTaskEntry.remove();
    }
    timerTaskEntry = entry;
  }

  public TimerTaskEntry getTimerTaskEntry() {
    return timerTaskEntry;
  }

  @Override
  public void run() {
    log.info("============={}任务执行", desc);
  }
}

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持自学编程网。

编程技巧