package cache
import (
"hash/fnv"
"net"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/cache"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Cache holds the state of the cache plugin: a positive cache (pcache) and a
// negative cache (ncache), the TTL bounds applied to each, and assorted options.
type Cache struct {
	Next  plugin.Handler // next handler; NOTE(review): presumably the next plugin in the chain — not used in this chunk
	Zones []string       // zones served by this cache; New defaults this to the root zone "."

	zonesMetricLabel string // label value passed to the cacheSize/cacheDrops/evictions metrics
	viewMetricLabel  string // label value passed to the cacheSize/cacheDrops/evictions metrics

	// Negative cache: NameError, NoData and ServerError replies (see ResponseWriter.set).
	ncache  *cache.Cache
	ncap    int           // capacity ncache was created with; New uses defaultCap for both
	nttl    time.Duration // upper bound for NameError/NoData cache durations (computeTTL)
	minnttl time.Duration // lower bound for NameError/NoData cache durations (computeTTL)

	// Positive cache: NoError and Delegation replies (see ResponseWriter.set).
	pcache  *cache.Cache
	pcap    int           // capacity pcache was created with; New uses defaultCap for both
	pttl    time.Duration // upper bound for positive cache durations (computeTTL)
	minpttl time.Duration // lower bound for positive cache durations (computeTTL)

	failttl time.Duration // fixed cache duration for ServerError replies

	// Prefetch settings; not read in this chunk. New sets prefetch=0,
	// duration=1m, percentage=10. NOTE(review): semantics defined elsewhere — confirm.
	prefetch   int
	duration   time.Duration
	percentage int

	// Stale-serving options; not read in this chunk. NOTE(review): semantics defined elsewhere — confirm.
	staleUpTo   time.Duration
	verifyStale bool

	pexcept []string // zones excluded from the positive cache
	nexcept []string // zones excluded from the negative cache
	keepttl bool     // not read in this chunk. NOTE(review): confirm meaning at its use site

	now func() time.Time // clock used when classifying and storing replies; New sets time.Now
}
// New returns a Cache covering the root zone with default settings. The
// positive and negative caches are created with defaultCap entries each.
// Positive durations are clamped to [minTTL, maxTTL] and denial durations to
// [minNTTL, maxNTTL]. Server failures are cached for minNTTL. Prefetch is
// disabled (0), with a 1 minute duration and a 10 percentage setting.
func New() *Cache {
	c := &Cache{Zones: []string{"."}, now: time.Now}

	// Positive (success) cache.
	c.pcap, c.pcache = defaultCap, cache.New(defaultCap)
	c.pttl, c.minpttl = maxTTL, minTTL

	// Negative (denial) cache.
	c.ncap, c.ncache = defaultCap, cache.New(defaultCap)
	c.nttl, c.minnttl = maxNTTL, minNTTL
	c.failttl = minNTTL

	// Prefetch defaults.
	c.prefetch, c.duration, c.percentage = 0, time.Minute, 10

	return c
}
// key reports whether reply m may be cached and, if so, returns its cache key.
// The key is derived from qname, the type of m's first question and the DO/CD
// bits.
//
// Nothing is cached for any of the following:
//   - a nil message
//   - a truncated reply
//   - a reply classified as OtherError, Meta or Update
//   - a reply with an empty question section, which has no question type to
//     key on (indexing Question[0] would otherwise panic)
func key(qname string, m *dns.Msg, t response.Type, do, cd bool) (bool, uint64) {
	if m == nil || m.Truncated {
		return false, 0
	}
	if t == response.OtherError || t == response.Meta || t == response.Update {
		return false, 0
	}
	if len(m.Question) == 0 {
		return false, 0
	}
	return true, hash(qname, m.Question[0].Qtype, do, cd)
}
// Flag encodings fed into the cache-key hash.
var (
	one  = []byte("1")
	zero = []byte("0")
)

// hash returns the 64-bit FNV-1 digest of, in order:
//   - the DO flag ("1"/"0")
//   - the CD flag ("1"/"0")
//   - qtype as two big-endian bytes
//   - the raw bytes of qname
//
// qname is hashed exactly as given; no case folding happens here.
func hash(qname string, qtype uint16, do, cd bool) uint64 {
	flag := func(set bool) []byte {
		if set {
			return one
		}
		return zero
	}

	digest := fnv.New64()
	digest.Write(flag(do))
	digest.Write(flag(cd))
	digest.Write([]byte{byte(qtype >> 8), byte(qtype)})
	digest.Write([]byte(qname))
	return digest.Sum64()
}
// computeTTL clamps msgTTL into [minTTL, maxTTL]. The lower bound is applied
// first, so if minTTL exceeds maxTTL the result is maxTTL.
func computeTTL(msgTTL, minTTL, maxTTL time.Duration) time.Duration {
	floored := msgTTL
	if minTTL > floored {
		floored = minTTL
	}
	if floored > maxTTL {
		return maxTTL
	}
	return floored
}
// ResponseWriter wraps a dns.ResponseWriter. Replies written through WriteMsg
// are stored in the embedded Cache. Unless prefetch is set, they are then
// forwarded to the wrapped writer.
type ResponseWriter struct {
	dns.ResponseWriter
	*Cache
	state  request.Request // request being answered; its name is used to build the cache key
	server string          // server label value for the cache metrics

	do         bool     // DO bit of the request; part of the cache key
	cd         bool     // CD bit of the request; part of the cache key
	ad         bool     // when neither do nor ad is set, WriteMsg clears the reply's AD bit
	prefetch   bool     // when true, WriteMsg/Write send nothing to the client
	remoteAddr net.Addr // when non-nil, returned by RemoteAddr instead of the wrapped writer's address

	wildcardFunc func() string // optional; its result is stored as the cached item's wildcard

	// These two fields shadow the embedded Cache's pexcept/nexcept, and set()
	// reads these copies. Whoever constructs a ResponseWriter must fill them in,
	// or the exclusion lists are silently ignored.
	pexcept []string // zones excluded from the positive cache
	nexcept []string // zones excluded from the negative cache
}
// newPrefetchResponseWriter returns a ResponseWriter for refreshing a cache
// entry on behalf of state. Because prefetch is set, WriteMsg may update the
// cache but never writes the reply back to the client.
func newPrefetchResponseWriter(server string, state request.Request, c *Cache) *ResponseWriter {
	addr := state.W.RemoteAddr()
	// Report the client as a TCP peer with the same IP/port/zone.
	// NOTE(review): presumably so that transport-dependent size limits do not
	// truncate the prefetched reply — confirm.
	if u, ok := addr.(*net.UDPAddr); ok {
		addr = &net.TCPAddr{IP: u.IP, Port: u.Port, Zone: u.Zone}
	}
	return &ResponseWriter{
		ResponseWriter: state.W,
		Cache:          c,
		state:          state,
		server:         server,
		do:             state.Do(),
		cd:             state.Req.CheckingDisabled,
		prefetch:       true,
		remoteAddr:     addr,
		// ResponseWriter's own pexcept/nexcept shadow the embedded Cache's, and
		// set() reads them. Copy the configured lists so zone exclusions still
		// apply when a prefetch refreshes an entry.
		pexcept: c.pexcept,
		nexcept: c.nexcept,
	}
}
// RemoteAddr implements dns.ResponseWriter. It returns the override address
// when one was set (as newPrefetchResponseWriter does). Otherwise it returns
// the wrapped writer's address.
func (w *ResponseWriter) RemoteAddr() net.Addr {
	if w.remoteAddr == nil {
		return w.ResponseWriter.RemoteAddr()
	}
	return w.remoteAddr
}
// WriteMsg implements dns.ResponseWriter. It classifies res and derives its
// cache duration. If res has a cache key and a positive duration, it is stored
// when it matches the original request; otherwise the mismatch is counted as a
// drop. Prefetch writers stop there. All other writers then pass every
// section through filterRRSlice with the duration in seconds, clear the AD bit
// unless do or ad is set, and forward res to the wrapped writer.
func (w *ResponseWriter) WriteMsg(res *dns.Msg) error {
	mt, _ := response.Typify(res, w.now().UTC())
	cacheable, k := key(w.state.Name(), res, mt, w.do, w.cd)

	// Cache duration for this reply:
	//   - server failures: the fixed failttl
	//   - denials: clamped to [minnttl, nttl]
	//   - everything else: clamped to [minpttl, pttl]
	lowest := dnsutil.MinimalTTL(res, mt)
	var duration time.Duration
	switch mt {
	case response.ServerError:
		duration = w.failttl
	case response.NameError, response.NoData:
		duration = computeTTL(lowest, w.minnttl, w.nttl)
	default:
		duration = computeTTL(lowest, w.minpttl, w.pttl)
	}

	if cacheable && duration > 0 {
		if !w.state.Match(res) {
			// Reply does not match the request: count it, do not store it.
			cacheDrops.WithLabelValues(w.server, w.zonesMetricLabel, w.viewMetricLabel).Inc()
		} else {
			w.set(res, k, mt, duration)
			cacheSize.WithLabelValues(w.server, Success, w.zonesMetricLabel, w.viewMetricLabel).Set(float64(w.pcache.Len()))
			cacheSize.WithLabelValues(w.server, Denial, w.zonesMetricLabel, w.viewMetricLabel).Set(float64(w.ncache.Len()))
		}
	}

	// A prefetch only refreshes the cache; nothing goes back to the client.
	if w.prefetch {
		return nil
	}

	// NOTE(review): presumably filterRRSlice caps record TTLs at ttl — confirm at its definition.
	ttl := uint32(duration.Seconds())
	for _, section := range []*[]dns.RR{&res.Answer, &res.Ns, &res.Extra} {
		*section = filterRRSlice(*section, ttl, false)
	}

	// Keep AD only when the request carried DO or AD.
	if !(w.do || w.ad) {
		res.AuthenticatedData = false
	}
	return w.ResponseWriter.WriteMsg(res)
}
// set stores m under key for duration. Where it goes depends on mt:
//   - NoError and Delegation replies go to the positive cache.
//   - NameError, NoData and ServerError replies go to the negative cache.
//   - OtherError is ignored silently.
//   - Any other classification is logged and ignored.
//
// Replies whose question name falls under the matching exclusion list are
// skipped. An insertion that evicts an entry bumps the evictions metric. A
// positive store made by a prefetch writer also removes key from the negative
// cache.
func (w *ResponseWriter) set(m *dns.Msg, key uint64, mt response.Type, duration time.Duration) {
	positive := false
	switch mt {
	case response.NoError, response.Delegation:
		positive = true
	case response.NameError, response.NoData, response.ServerError:
		// negative cache
	case response.OtherError:
		return
	default:
		log.Warningf("Caching called with unknown classification: %d", mt)
		return
	}

	except, store, label := w.nexcept, w.ncache, Denial
	if positive {
		except, store, label = w.pexcept, w.pcache, Success
	}

	if plugin.Zones(except).Matches(m.Question[0].Name) != "" {
		return
	}

	entry := newItem(m, w.now(), duration)
	if w.wildcardFunc != nil {
		entry.wildcard = w.wildcardFunc()
	}
	if store.Add(key, entry) {
		evictions.WithLabelValues(w.server, label, w.zonesMetricLabel, w.viewMetricLabel).Inc()
	}

	if positive && w.prefetch {
		w.ncache.Remove(key)
	}
}
// Write implements dns.ResponseWriter for raw wire-format replies, which are
// never cached; a warning is logged on every call. A prefetch writer discards
// buf and reports (0, nil). Any other writer passes buf to the wrapped writer.
func (w *ResponseWriter) Write(buf []byte) (int, error) {
	log.Warning("Caching called with Write: not caching reply")
	if !w.prefetch {
		return w.ResponseWriter.Write(buf)
	}
	return 0, nil
}
// verifyStaleResponseWriter wraps a *ResponseWriter and forwards only
// NOERROR/NXDOMAIN replies to it. The refreshed field records whether the
// last WriteMsg call forwarded its reply.
type verifyStaleResponseWriter struct {
	*ResponseWriter
	refreshed bool // true when the most recent WriteMsg forwarded a NOERROR/NXDOMAIN reply
}
// newVerifyStaleResponseWriter wraps w; refreshed starts out false.
func newVerifyStaleResponseWriter(w *ResponseWriter) *verifyStaleResponseWriter {
	return &verifyStaleResponseWriter{ResponseWriter: w}
}
// WriteMsg forwards res to the wrapped ResponseWriter only if its rcode is
// NOERROR or NXDOMAIN, and sets refreshed accordingly. Any other rcode is
// swallowed: nothing is written and nil is returned.
func (w *verifyStaleResponseWriter) WriteMsg(res *dns.Msg) error {
	w.refreshed = false
	switch res.Rcode {
	case dns.RcodeSuccess, dns.RcodeNameError:
		w.refreshed = true
		return w.ResponseWriter.WriteMsg(res)
	}
	return nil
}
// Default TTL bounds used by New.
const (
	// maxTTL and minTTL bound the cache duration of positive replies.
	maxTTL = dnsutil.MaximumDefaulTTL
	minTTL = dnsutil.MinimalDefaultTTL

	// maxNTTL and minNTTL bound denial replies; the ceiling is half the positive one.
	maxNTTL = maxTTL / 2
	minNTTL = minTTL
)

// defaultCap is the capacity New gives each of the positive and negative caches.
const defaultCap = 10000

// Success and Denial are the metric label values for the positive and
// negative caches respectively.
const (
	Success = "success"
	Denial  = "denial"
)