package smux import ( "bytes" "io" "net" "sync" "sync/atomic" "time" "github.com/pkg/errors" ) // Stream implements net.Conn type Stream struct { id uint32 rstflag int32 sess *Session buffer bytes.Buffer bufferLock sync.Mutex frameSize int chReadEvent chan struct{} // notify a read event die chan struct{} // flag the stream has closed dieLock sync.Mutex readDeadline atomic.Value writeDeadline atomic.Value } // newStream initiates a Stream struct func newStream(id uint32, frameSize int, sess *Session) *Stream { s := new(Stream) s.id = id s.chReadEvent = make(chan struct{}, 1) s.frameSize = frameSize s.sess = sess s.die = make(chan struct{}) return s } // ID returns the unique stream ID. func (s *Stream) ID() uint32 { return s.id } // Read implements net.Conn func (s *Stream) Read(b []byte) (n int, err error) { var deadline <-chan time.Time if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() { timer := time.NewTimer(d.Sub(time.Now())) defer timer.Stop() deadline = timer.C } READ: select { case <-s.die: return 0, errors.New(errBrokenPipe) case <-deadline: return n, errTimeout default: } s.bufferLock.Lock() n, err = s.buffer.Read(b) s.bufferLock.Unlock() if n > 0 { s.sess.returnTokens(n) return n, nil } else if atomic.LoadInt32(&s.rstflag) == 1 { _ = s.Close() return 0, io.EOF } select { case <-s.chReadEvent: goto READ case <-deadline: return n, errTimeout case <-s.die: return 0, errors.New(errBrokenPipe) } } // Write implements net.Conn func (s *Stream) Write(b []byte) (n int, err error) { var deadline <-chan time.Time if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() { timer := time.NewTimer(d.Sub(time.Now())) defer timer.Stop() deadline = timer.C } select { case <-s.die: return 0, errors.New(errBrokenPipe) default: } frames := s.split(b, cmdPSH, s.id) sent := 0 for k := range frames { req := writeRequest{ frame: frames[k], result: make(chan writeResult, 1), } select { case s.sess.writes <- req: case <-s.die: return sent, errors.New(errBrokenPipe) case <-deadline: return sent, errTimeout } select { case result := <-req.result: sent += result.n if result.err != nil { return sent, result.err } case <-s.die: return sent, errors.New(errBrokenPipe) case <-deadline: return sent, errTimeout } } return sent, nil } // Close implements net.Conn func (s *Stream) Close() error { s.dieLock.Lock() select { case <-s.die: s.dieLock.Unlock() return errors.New(errBrokenPipe) default: close(s.die) s.dieLock.Unlock() s.sess.streamClosed(s.id) _, err := s.sess.writeFrame(newFrame(cmdFIN, s.id)) return err } } // SetReadDeadline sets the read deadline as defined by // net.Conn.SetReadDeadline. // A zero time value disables the deadline. func (s *Stream) SetReadDeadline(t time.Time) error { s.readDeadline.Store(t) return nil } // SetWriteDeadline sets the write deadline as defined by // net.Conn.SetWriteDeadline. // A zero time value disables the deadline. func (s *Stream) SetWriteDeadline(t time.Time) error { s.writeDeadline.Store(t) return nil } // SetDeadline sets both read and write deadlines as defined by // net.Conn.SetDeadline. // A zero time value disables the deadlines. func (s *Stream) SetDeadline(t time.Time) error { if err := s.SetReadDeadline(t); err != nil { return err } if err := s.SetWriteDeadline(t); err != nil { return err } return nil } // session closes the stream func (s *Stream) sessionClose() { s.dieLock.Lock() defer s.dieLock.Unlock() select { case <-s.die: default: close(s.die) } } // LocalAddr satisfies net.Conn interface func (s *Stream) LocalAddr() net.Addr { if ts, ok := s.sess.conn.(interface { LocalAddr() net.Addr }); ok { return ts.LocalAddr() } return nil } // RemoteAddr satisfies net.Conn interface func (s *Stream) RemoteAddr() net.Addr { if ts, ok := s.sess.conn.(interface { RemoteAddr() net.Addr }); ok { return ts.RemoteAddr() } return nil } // pushBytes a slice into buffer func (s *Stream) pushBytes(p []byte) { s.bufferLock.Lock() s.buffer.Write(p) s.bufferLock.Unlock() } // recycleTokens transform remaining bytes to tokens(will truncate buffer) func (s *Stream) recycleTokens() (n int) { s.bufferLock.Lock() n = s.buffer.Len() s.buffer.Reset() s.bufferLock.Unlock() return } // split large byte buffer into smaller frames, reference only func (s *Stream) split(bts []byte, cmd byte, sid uint32) []Frame { frames := make([]Frame, 0, len(bts)/s.frameSize+1) for len(bts) > s.frameSize { frame := newFrame(cmd, sid) frame.data = bts[:s.frameSize] bts = bts[s.frameSize:] frames = append(frames, frame) } if len(bts) > 0 { frame := newFrame(cmd, sid) frame.data = bts frames = append(frames, frame) } return frames } // notify read event func (s *Stream) notifyReadEvent() { select { case s.chReadEvent <- struct{}{}: default: } } // mark this stream has been reset func (s *Stream) markRST() { atomic.StoreInt32(&s.rstflag, 1) } var errTimeout error = &timeoutError{} type timeoutError struct{} func (e *timeoutError) Error() string { return "i/o timeout" } func (e *timeoutError) Timeout() bool { return true } func (e *timeoutError) Temporary() bool { return true }