Merge pull request #2189 from burkemw3/lib-ify-connections
Decouple connections service from model
This commit is contained in:
commit
b158072a15
|
@ -28,6 +28,7 @@ import (
|
||||||
"github.com/calmh/logger"
|
"github.com/calmh/logger"
|
||||||
"github.com/juju/ratelimit"
|
"github.com/juju/ratelimit"
|
||||||
"github.com/syncthing/syncthing/lib/config"
|
"github.com/syncthing/syncthing/lib/config"
|
||||||
|
"github.com/syncthing/syncthing/lib/connections"
|
||||||
"github.com/syncthing/syncthing/lib/db"
|
"github.com/syncthing/syncthing/lib/db"
|
||||||
"github.com/syncthing/syncthing/lib/discover"
|
"github.com/syncthing/syncthing/lib/discover"
|
||||||
"github.com/syncthing/syncthing/lib/events"
|
"github.com/syncthing/syncthing/lib/events"
|
||||||
|
@ -577,13 +578,6 @@ func syncthingMain() {
|
||||||
symlinks.Supported = false
|
symlinks.Supported = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.MaxSendKbps > 0 {
|
|
||||||
writeRateLimit = ratelimit.NewBucketWithRate(float64(1000*opts.MaxSendKbps), int64(5*1000*opts.MaxSendKbps))
|
|
||||||
}
|
|
||||||
if opts.MaxRecvKbps > 0 {
|
|
||||||
readRateLimit = ratelimit.NewBucketWithRate(float64(1000*opts.MaxRecvKbps), int64(5*1000*opts.MaxRecvKbps))
|
|
||||||
}
|
|
||||||
|
|
||||||
if (opts.MaxRecvKbps > 0 || opts.MaxSendKbps > 0) && !opts.LimitBandwidthInLan {
|
if (opts.MaxRecvKbps > 0 || opts.MaxSendKbps > 0) && !opts.LimitBandwidthInLan {
|
||||||
lans, _ = osutil.GetLans()
|
lans, _ = osutil.GetLans()
|
||||||
networks := make([]string, 0, len(lans))
|
networks := make([]string, 0, len(lans))
|
||||||
|
@ -750,7 +744,7 @@ func syncthingMain() {
|
||||||
|
|
||||||
// Start connection management
|
// Start connection management
|
||||||
|
|
||||||
connectionSvc := newConnectionSvc(cfg, myID, m, tlsCfg, cachedDiscovery, relaySvc)
|
connectionSvc := connections.NewConnectionSvc(cfg, myID, m, tlsCfg, cachedDiscovery, relaySvc, bepProtocolName, tlsDefaultCommonName, lans)
|
||||||
mainSvc.Add(connectionSvc)
|
mainSvc.Add(connectionSvc)
|
||||||
|
|
||||||
if cpuProfile {
|
if cpuProfile {
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
package main
|
package connections
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
@ -15,6 +15,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/juju/ratelimit"
|
||||||
"github.com/syncthing/syncthing/lib/config"
|
"github.com/syncthing/syncthing/lib/config"
|
||||||
"github.com/syncthing/syncthing/lib/discover"
|
"github.com/syncthing/syncthing/lib/discover"
|
||||||
"github.com/syncthing/syncthing/lib/events"
|
"github.com/syncthing/syncthing/lib/events"
|
||||||
|
@ -35,17 +36,39 @@ var (
|
||||||
listeners = make(map[string]ListenerFactory, 0)
|
listeners = make(map[string]ListenerFactory, 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Model interface {
|
||||||
|
AddConnection(conn model.Connection)
|
||||||
|
ConnectedTo(remoteID protocol.DeviceID) bool
|
||||||
|
IsPaused(remoteID protocol.DeviceID) bool
|
||||||
|
|
||||||
|
// An index was received from the peer device
|
||||||
|
Index(deviceID protocol.DeviceID, folder string, files []protocol.FileInfo, flags uint32, options []protocol.Option)
|
||||||
|
// An index update was received from the peer device
|
||||||
|
IndexUpdate(deviceID protocol.DeviceID, folder string, files []protocol.FileInfo, flags uint32, options []protocol.Option)
|
||||||
|
// A request was made by the peer device
|
||||||
|
Request(deviceID protocol.DeviceID, folder string, name string, offset int64, hash []byte, flags uint32, options []protocol.Option, buf []byte) error
|
||||||
|
// A cluster configuration message was received
|
||||||
|
ClusterConfig(deviceID protocol.DeviceID, config protocol.ClusterConfigMessage)
|
||||||
|
// The peer device closed the connection
|
||||||
|
Close(deviceID protocol.DeviceID, err error)
|
||||||
|
}
|
||||||
|
|
||||||
// The connection service listens on TLS and dials configured unconnected
|
// The connection service listens on TLS and dials configured unconnected
|
||||||
// devices. Successful connections are handed to the model.
|
// devices. Successful connections are handed to the model.
|
||||||
type connectionSvc struct {
|
type connectionSvc struct {
|
||||||
*suture.Supervisor
|
*suture.Supervisor
|
||||||
cfg *config.Wrapper
|
cfg *config.Wrapper
|
||||||
myID protocol.DeviceID
|
myID protocol.DeviceID
|
||||||
model *model.Model
|
model Model
|
||||||
tlsCfg *tls.Config
|
tlsCfg *tls.Config
|
||||||
discoverer discover.Finder
|
discoverer discover.Finder
|
||||||
conns chan model.IntermediateConnection
|
conns chan model.IntermediateConnection
|
||||||
relaySvc *relay.Svc
|
relaySvc *relay.Svc
|
||||||
|
bepProtocolName string
|
||||||
|
tlsDefaultCommonName string
|
||||||
|
lans []*net.IPNet
|
||||||
|
writeRateLimit *ratelimit.Bucket
|
||||||
|
readRateLimit *ratelimit.Bucket
|
||||||
|
|
||||||
lastRelayCheck map[protocol.DeviceID]time.Time
|
lastRelayCheck map[protocol.DeviceID]time.Time
|
||||||
|
|
||||||
|
@ -54,16 +77,20 @@ type connectionSvc struct {
|
||||||
relaysEnabled bool
|
relaysEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, mdl *model.Model, tlsCfg *tls.Config, discoverer discover.Finder, relaySvc *relay.Svc) *connectionSvc {
|
func NewConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *tls.Config, discoverer discover.Finder, relaySvc *relay.Svc,
|
||||||
|
bepProtocolName string, tlsDefaultCommonName string, lans []*net.IPNet) suture.Service {
|
||||||
svc := &connectionSvc{
|
svc := &connectionSvc{
|
||||||
Supervisor: suture.NewSimple("connectionSvc"),
|
Supervisor: suture.NewSimple("connectionSvc"),
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
myID: myID,
|
myID: myID,
|
||||||
model: mdl,
|
model: mdl,
|
||||||
tlsCfg: tlsCfg,
|
tlsCfg: tlsCfg,
|
||||||
discoverer: discoverer,
|
discoverer: discoverer,
|
||||||
relaySvc: relaySvc,
|
relaySvc: relaySvc,
|
||||||
conns: make(chan model.IntermediateConnection),
|
conns: make(chan model.IntermediateConnection),
|
||||||
|
bepProtocolName: bepProtocolName,
|
||||||
|
tlsDefaultCommonName: tlsDefaultCommonName,
|
||||||
|
lans: lans,
|
||||||
|
|
||||||
connType: make(map[protocol.DeviceID]model.ConnectionType),
|
connType: make(map[protocol.DeviceID]model.ConnectionType),
|
||||||
relaysEnabled: cfg.Options().RelaysEnabled,
|
relaysEnabled: cfg.Options().RelaysEnabled,
|
||||||
|
@ -71,6 +98,13 @@ func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, mdl *model.Mo
|
||||||
}
|
}
|
||||||
cfg.Subscribe(svc)
|
cfg.Subscribe(svc)
|
||||||
|
|
||||||
|
if svc.cfg.Options().MaxSendKbps > 0 {
|
||||||
|
svc.writeRateLimit = ratelimit.NewBucketWithRate(float64(1000*svc.cfg.Options().MaxSendKbps), int64(5*1000*svc.cfg.Options().MaxSendKbps))
|
||||||
|
}
|
||||||
|
if svc.cfg.Options().MaxRecvKbps > 0 {
|
||||||
|
svc.readRateLimit = ratelimit.NewBucketWithRate(float64(1000*svc.cfg.Options().MaxRecvKbps), int64(5*1000*svc.cfg.Options().MaxRecvKbps))
|
||||||
|
}
|
||||||
|
|
||||||
// There are several moving parts here; one routine per listening address
|
// There are several moving parts here; one routine per listening address
|
||||||
// to handle incoming connections, one routine to periodically attempt
|
// to handle incoming connections, one routine to periodically attempt
|
||||||
// outgoing connections, one routine to the the common handling
|
// outgoing connections, one routine to the the common handling
|
||||||
|
@ -97,7 +131,7 @@ func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, mdl *model.Mo
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if debugNet {
|
if debug {
|
||||||
l.Debugln("listening on", uri.String())
|
l.Debugln("listening on", uri.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -123,7 +157,7 @@ next:
|
||||||
// of the TLS handshake. Unfortunately this can't be a hard error,
|
// of the TLS handshake. Unfortunately this can't be a hard error,
|
||||||
// because there are implementations out there that don't support
|
// because there are implementations out there that don't support
|
||||||
// protocol negotiation (iOS for one...).
|
// protocol negotiation (iOS for one...).
|
||||||
if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != bepProtocolName {
|
if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != s.bepProtocolName {
|
||||||
l.Infof("Peer %s did not negotiate bep/1.0", c.Conn.RemoteAddr())
|
l.Infof("Peer %s did not negotiate bep/1.0", c.Conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,7 +176,7 @@ next:
|
||||||
// The device ID should not be that of ourselves. It can happen
|
// The device ID should not be that of ourselves. It can happen
|
||||||
// though, especially in the presence of NAT hairpinning, multiple
|
// though, especially in the presence of NAT hairpinning, multiple
|
||||||
// clients between the same NAT gateway, and global discovery.
|
// clients between the same NAT gateway, and global discovery.
|
||||||
if remoteID == myID {
|
if remoteID == s.myID {
|
||||||
l.Infof("Connected to myself (%s) - should not happen", remoteID)
|
l.Infof("Connected to myself (%s) - should not happen", remoteID)
|
||||||
c.Conn.Close()
|
c.Conn.Close()
|
||||||
continue
|
continue
|
||||||
|
@ -154,7 +188,7 @@ next:
|
||||||
ct, ok := s.connType[remoteID]
|
ct, ok := s.connType[remoteID]
|
||||||
s.mut.RUnlock()
|
s.mut.RUnlock()
|
||||||
if ok && !ct.IsDirect() && c.Type.IsDirect() {
|
if ok && !ct.IsDirect() && c.Type.IsDirect() {
|
||||||
if debugNet {
|
if debug {
|
||||||
l.Debugln("Switching connections", remoteID)
|
l.Debugln("Switching connections", remoteID)
|
||||||
}
|
}
|
||||||
s.model.Close(remoteID, fmt.Errorf("switching connections"))
|
s.model.Close(remoteID, fmt.Errorf("switching connections"))
|
||||||
|
@ -181,7 +215,7 @@ next:
|
||||||
// the certificate and used another name.
|
// the certificate and used another name.
|
||||||
certName := deviceCfg.CertName
|
certName := deviceCfg.CertName
|
||||||
if certName == "" {
|
if certName == "" {
|
||||||
certName = tlsDefaultCommonName
|
certName = s.tlsDefaultCommonName
|
||||||
}
|
}
|
||||||
err := remoteCert.VerifyHostname(certName)
|
err := remoteCert.VerifyHostname(certName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -199,20 +233,20 @@ next:
|
||||||
limit := s.shouldLimit(c.Conn.RemoteAddr())
|
limit := s.shouldLimit(c.Conn.RemoteAddr())
|
||||||
|
|
||||||
wr := io.Writer(c.Conn)
|
wr := io.Writer(c.Conn)
|
||||||
if limit && writeRateLimit != nil {
|
if limit && s.writeRateLimit != nil {
|
||||||
wr = &limitedWriter{c.Conn, writeRateLimit}
|
wr = NewWriteLimiter(c.Conn, s.writeRateLimit)
|
||||||
}
|
}
|
||||||
|
|
||||||
rd := io.Reader(c.Conn)
|
rd := io.Reader(c.Conn)
|
||||||
if limit && readRateLimit != nil {
|
if limit && s.readRateLimit != nil {
|
||||||
rd = &limitedReader{c.Conn, readRateLimit}
|
rd = NewReadLimiter(c.Conn, s.readRateLimit)
|
||||||
}
|
}
|
||||||
|
|
||||||
name := fmt.Sprintf("%s-%s (%s)", c.Conn.LocalAddr(), c.Conn.RemoteAddr(), c.Type)
|
name := fmt.Sprintf("%s-%s (%s)", c.Conn.LocalAddr(), c.Conn.RemoteAddr(), c.Type)
|
||||||
protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression)
|
protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression)
|
||||||
|
|
||||||
l.Infof("Established secure connection to %s at %s", remoteID, name)
|
l.Infof("Established secure connection to %s at %s", remoteID, name)
|
||||||
if debugNet {
|
if debug {
|
||||||
l.Debugf("cipher suite: %04X in lan: %t", c.Conn.ConnectionState().CipherSuite, !limit)
|
l.Debugf("cipher suite: %04X in lan: %t", c.Conn.ConnectionState().CipherSuite, !limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -245,7 +279,7 @@ func (s *connectionSvc) connect() {
|
||||||
for {
|
for {
|
||||||
nextDevice:
|
nextDevice:
|
||||||
for deviceID, deviceCfg := range s.cfg.Devices() {
|
for deviceID, deviceCfg := range s.cfg.Devices() {
|
||||||
if deviceID == myID {
|
if deviceID == s.myID {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -291,12 +325,12 @@ func (s *connectionSvc) connect() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if debugNet {
|
if debug {
|
||||||
l.Debugln("dial", deviceCfg.DeviceID, uri.String())
|
l.Debugln("dial", deviceCfg.DeviceID, uri.String())
|
||||||
}
|
}
|
||||||
conn, err := dialer(uri, s.tlsCfg)
|
conn, err := dialer(uri, s.tlsCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if debugNet {
|
if debug {
|
||||||
l.Debugln("dial failed", deviceCfg.DeviceID, uri.String(), err)
|
l.Debugln("dial failed", deviceCfg.DeviceID, uri.String(), err)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
@ -323,11 +357,11 @@ func (s *connectionSvc) connect() {
|
||||||
|
|
||||||
reconIntv := time.Duration(s.cfg.Options().RelayReconnectIntervalM) * time.Minute
|
reconIntv := time.Duration(s.cfg.Options().RelayReconnectIntervalM) * time.Minute
|
||||||
if last, ok := s.lastRelayCheck[deviceID]; ok && time.Since(last) < reconIntv {
|
if last, ok := s.lastRelayCheck[deviceID]; ok && time.Since(last) < reconIntv {
|
||||||
if debugNet {
|
if debug {
|
||||||
l.Debugln("Skipping connecting via relay to", deviceID, "last checked at", last)
|
l.Debugln("Skipping connecting via relay to", deviceID, "last checked at", last)
|
||||||
}
|
}
|
||||||
continue nextDevice
|
continue nextDevice
|
||||||
} else if debugNet {
|
} else if debug {
|
||||||
l.Debugln("Trying relay connections to", deviceID, relays)
|
l.Debugln("Trying relay connections to", deviceID, relays)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -342,21 +376,21 @@ func (s *connectionSvc) connect() {
|
||||||
|
|
||||||
inv, err := client.GetInvitationFromRelay(uri, deviceID, s.tlsCfg.Certificates)
|
inv, err := client.GetInvitationFromRelay(uri, deviceID, s.tlsCfg.Certificates)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if debugNet {
|
if debug {
|
||||||
l.Debugf("Failed to get invitation for %s from %s: %v", deviceID, uri, err)
|
l.Debugf("Failed to get invitation for %s from %s: %v", deviceID, uri, err)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
} else if debugNet {
|
} else if debug {
|
||||||
l.Debugln("Succesfully retrieved relay invitation", inv, "from", uri)
|
l.Debugln("Succesfully retrieved relay invitation", inv, "from", uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := client.JoinSession(inv)
|
conn, err := client.JoinSession(inv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if debugNet {
|
if debug {
|
||||||
l.Debugf("Failed to join relay session %s: %v", inv, err)
|
l.Debugf("Failed to join relay session %s: %v", inv, err)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
} else if debugNet {
|
} else if debug {
|
||||||
l.Debugln("Sucessfully joined relay session", inv)
|
l.Debugln("Sucessfully joined relay session", inv)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -412,7 +446,7 @@ func (s *connectionSvc) shouldLimit(addr net.Addr) bool {
|
||||||
if !ok {
|
if !ok {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
for _, lan := range lans {
|
for _, lan := range s.lans {
|
||||||
if lan.Contains(tcpaddr.IP) {
|
if lan.Contains(tcpaddr.IP) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -444,3 +478,10 @@ func (s *connectionSvc) CommitConfiguration(from, to config.Configuration) bool
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// serviceFunc wraps a function to create a suture.Service without stop
|
||||||
|
// functionality.
|
||||||
|
type serviceFunc func()
|
||||||
|
|
||||||
|
func (f serviceFunc) Serve() { f() }
|
||||||
|
func (f serviceFunc) Stop() {}
|
|
@ -4,7 +4,7 @@
|
||||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
package main
|
package connections
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
@ -33,7 +33,7 @@ func tcpDialer(uri *url.URL, tlsCfg *tls.Config) (*tls.Conn, error) {
|
||||||
|
|
||||||
raddr, err := net.ResolveTCPAddr("tcp", uri.Host)
|
raddr, err := net.ResolveTCPAddr("tcp", uri.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if debugNet {
|
if debug {
|
||||||
l.Debugln(err)
|
l.Debugln(err)
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -41,7 +41,7 @@ func tcpDialer(uri *url.URL, tlsCfg *tls.Config) (*tls.Conn, error) {
|
||||||
|
|
||||||
conn, err := net.DialTCP("tcp", nil, raddr)
|
conn, err := net.DialTCP("tcp", nil, raddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if debugNet {
|
if debug {
|
||||||
l.Debugln(err)
|
l.Debugln(err)
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -81,7 +81,7 @@ func tcpListener(uri *url.URL, tlsCfg *tls.Config, conns chan<- model.Intermedia
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if debugNet {
|
if debug {
|
||||||
l.Debugln("connect from", conn.RemoteAddr())
|
l.Debugln("connect from", conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
// Copyright (C) 2014 The Syncthing Authors.
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||||
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
package connections
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/calmh/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
debug = strings.Contains(os.Getenv("STTRACE"), "connections") || os.Getenv("STTRACE") == "all"
|
||||||
|
l = logger.DefaultLogger
|
||||||
|
)
|
|
@ -4,7 +4,7 @@
|
||||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
package main
|
package connections
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
@ -12,13 +12,20 @@ import (
|
||||||
"github.com/juju/ratelimit"
|
"github.com/juju/ratelimit"
|
||||||
)
|
)
|
||||||
|
|
||||||
type limitedReader struct {
|
type LimitedReader struct {
|
||||||
r io.Reader
|
reader io.Reader
|
||||||
bucket *ratelimit.Bucket
|
bucket *ratelimit.Bucket
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *limitedReader) Read(buf []byte) (int, error) {
|
func NewReadLimiter(r io.Reader, b *ratelimit.Bucket) *LimitedReader {
|
||||||
n, err := r.r.Read(buf)
|
return &LimitedReader{
|
||||||
|
reader: r,
|
||||||
|
bucket: b,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *LimitedReader) Read(buf []byte) (int, error) {
|
||||||
|
n, err := r.reader.Read(buf)
|
||||||
if r.bucket != nil {
|
if r.bucket != nil {
|
||||||
r.bucket.Wait(int64(n))
|
r.bucket.Wait(int64(n))
|
||||||
}
|
}
|
|
@ -4,7 +4,7 @@
|
||||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
package main
|
package connections
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
@ -12,14 +12,21 @@ import (
|
||||||
"github.com/juju/ratelimit"
|
"github.com/juju/ratelimit"
|
||||||
)
|
)
|
||||||
|
|
||||||
type limitedWriter struct {
|
type LimitedWriter struct {
|
||||||
w io.Writer
|
writer io.Writer
|
||||||
bucket *ratelimit.Bucket
|
bucket *ratelimit.Bucket
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *limitedWriter) Write(buf []byte) (int, error) {
|
func NewWriteLimiter(w io.Writer, b *ratelimit.Bucket) *LimitedWriter {
|
||||||
|
return &LimitedWriter{
|
||||||
|
writer: w,
|
||||||
|
bucket: b,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *LimitedWriter) Write(buf []byte) (int, error) {
|
||||||
if w.bucket != nil {
|
if w.bucket != nil {
|
||||||
w.bucket.Wait(int64(len(buf)))
|
w.bucket.Wait(int64(len(buf)))
|
||||||
}
|
}
|
||||||
return w.w.Write(buf)
|
return w.writer.Write(buf)
|
||||||
}
|
}
|
Loading…
Reference in New Issue