peerguard/fileshare/downloader.go
2024-08-03 19:47:23 +08:00

149 lines
3.4 KiB
Go

package fileshare
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/rkonfj/peerguard/disco"
"github.com/rkonfj/peerguard/rdt"
)
// FileHandle represents one file request carried over a stream opened by
// Downloader.Request. Handshake must succeed before File returns a reader.
type FileHandle struct {
	// Filename is the file name Downloader.Request extracted from the share URL.
	Filename string

	c     net.Conn  // stream to the sharing peer
	index uint16    // file index sent in the GET request
	fSize uint32    // total file size reported by the peer in the handshake header
	f     io.Reader // body reader, limited to fSize minus the handshake offset
}
// Handshake sends a GET request for this handle's file and reads the peer's
// 5-byte response header: one status byte followed by the total file size as a
// big-endian uint32.
//
// offset is the number of bytes the caller already holds locally. When
// offset > 0, sha256Checksum must be the 32-byte SHA-256 of those bytes, and
// the peer must answer with status 20 for the transfer to resume at offset.
// On success, File returns a reader limited to the remaining fSize-offset bytes.
func (h *FileHandle) Handshake(offset uint32, sha256Checksum []byte) error {
	// buildGet advertises a fixed 36-byte extension (4-byte offset + 32-byte
	// checksum). A checksum of any other length would make that advertised
	// length disagree with the bytes actually written.
	if offset > 0 && len(sha256Checksum) != 32 {
		return fmt.Errorf("sha256 checksum must be 32 bytes, got %d", len(sha256Checksum))
	}
	if _, err := h.c.Write(buildGet(h.index, offset, sha256Checksum)); err != nil {
		return err
	}

	header := make([]byte, 5)
	// Bound only the header read with a timeout. Clear the deadline on every
	// path, including failure, so it never lingers on the connection.
	// Deadline errors are ignored (best effort), as before.
	h.c.SetReadDeadline(time.Now().Add(5 * time.Second))
	_, err := io.ReadFull(h.c, header)
	h.c.SetReadDeadline(time.Time{})
	if err != nil {
		return fmt.Errorf("read header: %w", err)
	}

	switch header[0] {
	case 0, 20:
		// Accepted statuses; 20 is additionally required when resuming (below).
	case 1:
		return errors.New("bad request. maybe the version is lower than peer")
	case 2:
		return errors.New("file not found")
	case 4:
		return errors.New("download file size is less than local file")
	case 5:
		return errors.New("local file is not part of the file to be downloaded")
	default:
		return errors.New("invalid protocol header")
	}
	if offset > 0 && header[0] != 20 {
		return errors.New("sha256 checksum non matched for [0, offset)")
	}

	size := binary.BigEndian.Uint32(header[1:])
	// size-offset is uint32 arithmetic. Without this guard, a peer reporting
	// a size smaller than offset would wrap around to a ~4 GiB limit instead
	// of failing.
	if size < offset {
		return errors.New("download file size is less than local file")
	}
	h.fSize = size
	h.f = io.LimitReader(h.c, int64(size-offset))
	return nil
}
// File returns the file-body reader and the total file size reported by the
// peer. It returns an error until Handshake has completed successfully. The
// reader yields only the bytes after the offset given to Handshake.
func (h *FileHandle) File() (io.Reader, uint32, error) {
	if h.f != nil {
		return h.f, h.fSize, nil
	}
	return nil, 0, errors.New("handshake first")
}
// Sha256 reads exactly 32 bytes from the underlying stream and returns them as
// the checksum.
// NOTE(review): presumably the peer sends this after the file body, so call
// it only once the File reader is drained — confirm against the publisher.
func (h *FileHandle) Sha256() ([]byte, error) {
	var sum [32]byte
	_, err := io.ReadFull(h.c, sum[:])
	if err != nil {
		return nil, fmt.Errorf("read checksum failed: %w", err)
	}
	return sum[:], nil
}
// Read is the callback that Downloader.Request invokes with the handle for
// the requested file. Its error becomes Request's return value.
type Read func(handle *FileHandle) error
// Downloader fetches a shared file from a remote peer; see Request.
type Downloader struct {
	Network       string // network name passed to PublicNetwork
	Server        string // server address passed to PublicNetwork
	PrivateKey    string // private key passed to PublicNetwork
	ListenUDPPort int    // local UDP port; the rdt stats server uses this port + 100
}
// Request opens a stream to the peer named in shareURL and passes a FileHandle
// for the referenced file to read. It returns read's error.
//
// shareURL is parsed as follows:
//   - the host is the remote peer ID;
//   - the path's directory part is the uint16 file index;
//   - the last path element, query-unescaped, is the file name.
//
// Request listens for p2p packets on d.ListenUDPPort and runs an rdt listener,
// with a stats server on ListenUDPPort+100. The stream, listener and packet
// connection are all released before Request returns. If ctx is cancelled
// while read is running, a close frame is sent and the stream is closed.
func (d *Downloader) Request(ctx context.Context, shareURL string, read Read) error {
	// Validate the URL before acquiring any network resources, so a malformed
	// URL cannot leave a listener behind.
	resourceURL, err := url.Parse(shareURL)
	if err != nil {
		return fmt.Errorf("invalid URL: %w", err)
	}
	dir, filename := path.Split(resourceURL.Path)
	// The index is a uint16 on the wire. ParseUint rejects negative values
	// (which previously wrapped around) and accepts the full 0..65535 range.
	index, err := strconv.ParseUint(strings.Trim(dir, "/"), 10, 16)
	if err != nil {
		return fmt.Errorf("invalid URL: %w", err)
	}
	fn, err := url.QueryUnescape(filename)
	if err != nil {
		fn = filename // fall back to the raw path element
	}

	pnet := PublicNetwork{Name: d.Network, Server: d.Server, PrivateKey: d.PrivateKey}
	packetConn, err := pnet.ListenPacket(d.ListenUDPPort)
	if err != nil {
		return fmt.Errorf("listen p2p packet failed: %w", err)
	}
	defer packetConn.Close()

	listener, err := rdt.Listen(packetConn, rdt.EnableStatsServer(fmt.Sprintf(":%d", d.ListenUDPPort+100)))
	if err != nil {
		return fmt.Errorf("listen rdt: %w", err)
	}
	defer listener.Close()

	conn, err := listener.OpenStream(disco.PeerID(resourceURL.Host))
	if err != nil {
		return fmt.Errorf("dial server failed: %w", err)
	}
	defer conn.Close()

	// Watch for cancellation while read runs. The done channel lets the
	// goroutine exit when Request returns, instead of blocking forever on a
	// context that is never cancelled.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-ctx.Done():
			// Best effort: the stream is being torn down anyway.
			conn.Write(buildClose())
			conn.Close()
		case <-done:
		}
	}()

	// Runs first among the defers: send the close frame before conn is closed.
	defer conn.Write(buildClose())
	return read(&FileHandle{
		Filename: fn,
		c:        conn,
		index:    uint16(index),
	})
}
// buildGet encodes a GET request frame for the file at index.
//
// Layout: byte 0 is 0 and byte 1 is the extension length marker (36 when
// partSize > 0, otherwise 0). Bytes 2-3 hold index as big-endian. When
// partSize > 0, the 4-byte big-endian partSize follows, then checksum
// verbatim.
func buildGet(index uint16, partSize uint32, checksum []byte) []byte {
	if partSize == 0 {
		return binary.BigEndian.AppendUint16([]byte{0, 0}, index)
	}
	frame := make([]byte, 0, 8+len(checksum))
	frame = append(frame, 0, 36)
	frame = binary.BigEndian.AppendUint16(frame, index)
	frame = binary.BigEndian.AppendUint32(frame, partSize)
	return append(frame, checksum...)
}
// buildClose returns a freshly allocated single-byte frame, 0x01, which
// Request writes to the stream before closing it.
func buildClose() []byte {
	frame := []byte{0x01}
	return frame
}