aboutsummaryrefslogtreecommitdiff
path: root/cmd/mybittorrent/peer.go
diff options
context:
space:
mode:
Diffstat (limited to 'cmd/mybittorrent/peer.go')
-rw-r--r--cmd/mybittorrent/peer.go69
1 files changed, 43 insertions, 26 deletions
diff --git a/cmd/mybittorrent/peer.go b/cmd/mybittorrent/peer.go
index 3b40493..aee5f26 100644
--- a/cmd/mybittorrent/peer.go
+++ b/cmd/mybittorrent/peer.go
@@ -7,8 +7,8 @@ import (
"errors"
"fmt"
"io"
+ "log/slog"
"net"
- "os"
)
type HandshakeMessage []byte
@@ -17,7 +17,7 @@ type Peer struct {
conn net.Conn
handshake HandshakeMessage
ct *ClientTorrent
- msgCh chan *IncomingMessage
+ unchoked bool
}
func (p *Peer) PeerIdHexString() string {
@@ -37,31 +37,36 @@ func (p *Peer) Close() error {
return nil
}
-func (p *Peer) DownloadPiece(outFile string, index int) error {
- if index >= p.ct.Meta.PieceCount() {
+func (p *Peer) DownloadPiece(w io.Writer, piece Piece) error {
+ if piece.Index >= p.ct.Meta.PieceCount() {
return nil
}
var data = new(bytes.Buffer)
- blockLens := p.ct.Meta.BlockLens(index)
var blockIndex = 0
requestFn := func() error {
- if blockIndex == len(blockLens) {
+ if blockIndex == len(piece.Blocks) {
return nil
}
r := RequestPayload{
- Index: uint32(index),
+ Index: uint32(piece.Index),
Begin: uint32(blockIndex * BlockSize),
- Length: blockLens[blockIndex],
+ Length: piece.Blocks[blockIndex],
}
return p.WriteMessage(MessageTypeRequest, r.Bytes())
}
+ if p.unchoked {
+ if err := requestFn(); err != nil {
+ return err
+ }
+ }
+
for {
- if blockIndex == len(blockLens) {
+ if blockIndex == len(piece.Blocks) {
break
}
@@ -85,6 +90,7 @@ func (p *Peer) DownloadPiece(outFile string, index int) error {
if err = requestFn(); err != nil {
return err
}
+ p.unchoked = true
case MessageTypePiece:
var block BlockPayload
@@ -105,31 +111,24 @@ func (p *Peer) DownloadPiece(outFile string, index int) error {
return err
}
+ case MessageTypeChoke:
+ p.unchoked = false
+
default:
return fmt.Errorf("unimplemented message type: %d", msg.MessageType)
}
}
- if !p.ct.Meta.CheckHash(index, data.Bytes()) {
- return fmt.Errorf("invalid hash value")
- }
-
- out, err := os.Create(outFile)
- if err != nil {
- return errors.Join(fmt.Errorf("failed to create file"), err)
- }
+ slog.Debug("piece downloaded", "piece index", piece.Index)
- _, err = out.Write(data.Bytes())
- if err != nil {
- return errors.Join(fmt.Errorf("failed to write file"), err)
+ if err := piece.CheckHash(data.Bytes()); err != nil {
+ return fmt.Errorf("invalid hash value: %v", err)
}
- err = out.Close()
- if err != nil {
- return errors.Join(fmt.Errorf("failed to close file"), err)
- }
-
- return nil
+ slog.Debug("hash ok", "piece index", piece.Index)
+ _, err := w.Write(data.Bytes())
+ slog.Debug("written to writer", "piece index", piece.Index)
+ return err
}
func (p *Peer) ReadMessage() (*IncomingMessage, error) {
@@ -179,3 +178,21 @@ func (p *Peer) WriteMessage(t MessageType, payload []byte) error {
_, err := msg.Write(payload)
return err
}
+
+func (p *Peer) Download(pieceCh chan Piece, fileCh chan FileResult) {
+ for piece := range pieceCh {
+ fw := FileWriter{
+ ch: fileCh,
+ piece: piece,
+ }
+
+ err := p.DownloadPiece(&fw, piece)
+ if err != nil {
+ slog.Error("failed to download piece", "piece", piece, "err", err)
+ pieceCh <- piece
+ continue
+ }
+
+ slog.Debug("downloaded piece", "piece index", piece.Index, "left", len(pieceCh))
+ }
+}