aboutsummaryrefslogtreecommitdiff
path: root/cmd/mybittorrent
diff options
context:
space:
mode:
authorjet2tlf <jet2tlf@gmail.com>2024-06-03 18:14:24 +0000
committerjet2tlf <jet2tlf@gmail.com>2024-06-03 18:14:24 +0000
commit853be358804a6e30e857035ffda81a06df3f6b74 (patch)
treeeae9b736261ef6887c4070c7bf1e5a441ccb5319 /cmd/mybittorrent
parentadf38a1dbd085c19c4c87ad242e0b340f1655fcb (diff)
downloadbittorrent-go-853be358804a6e30e857035ffda81a06df3f6b74.tar.gz
bittorrent-go-853be358804a6e30e857035ffda81a06df3f6b74.zip
codecrafters submit [skip ci]
Diffstat (limited to 'cmd/mybittorrent')
-rw-r--r--cmd/mybittorrent/client.go51
-rw-r--r--cmd/mybittorrent/main.go54
-rw-r--r--cmd/mybittorrent/message.go76
-rw-r--r--cmd/mybittorrent/meta.go49
-rw-r--r--cmd/mybittorrent/peer.go181
5 files changed, 376 insertions, 35 deletions
diff --git a/cmd/mybittorrent/client.go b/cmd/mybittorrent/client.go
index ef3d860..8b3aac8 100644
--- a/cmd/mybittorrent/client.go
+++ b/cmd/mybittorrent/client.go
@@ -3,7 +3,6 @@ package main
import (
"bytes"
"encoding/binary"
- "encoding/hex"
"fmt"
"io"
"log/slog"
@@ -19,7 +18,7 @@ import (
type Client struct {
PeerId string
Port int
- Torrents map[string]ClientTorrent
+ Torrents map[string]*ClientTorrent
}
type ClientTorrent struct {
@@ -34,17 +33,15 @@ type PeerResponse struct {
Peers []string
}
-type HandshakeMessage []byte
-
func NewClient(peerId string, port int) *Client {
return &Client{
PeerId: peerId,
Port: port,
- Torrents: make(map[string]ClientTorrent),
+ Torrents: make(map[string]*ClientTorrent),
}
}
-func (c *Client) AddTorrentFile(filename string) (ClientTorrent, error) {
+func (c *Client) AddTorrentFile(filename string) error {
f, err := os.Open(filename)
defer func(f *os.File) {
err := f.Close()
@@ -60,19 +57,17 @@ func (c *Client) AddTorrentFile(filename string) (ClientTorrent, error) {
var meta Meta
if err = bencode.Unmarshal(f, &meta); err != nil {
- return ClientTorrent{}, err
+ return err
}
- t := ClientTorrent{
+ c.Torrents[filename] = &ClientTorrent{
Meta: meta,
Uploaded: 0,
Downloaded: 0,
Left: meta.Info.Length,
}
- c.Torrents[filename] = t
-
- return t, nil
+ return nil
}
func (ct ClientTorrent) getUrl(c Client) (string, error) {
@@ -141,13 +136,7 @@ func (c *Client) GetPeers(filename string) (PeerResponse, error) {
}, err
}
-func (m HandshakeMessage) PeerIdHex() string {
-
- return hex.EncodeToString(m[48:])
-
-}
-
-func (c *Client) Handshake(filename, peerAddr string) (HandshakeMessage, error) {
+func (c *Client) Handshake(filename, peerAddr string) (*Peer, error) {
ct, ok := c.Torrents[filename]
if !ok {
return nil, fmt.Errorf("missing torrent file: %s", filename)
@@ -167,12 +156,6 @@ func (c *Client) Handshake(filename, peerAddr string) (HandshakeMessage, error)
buf.WriteString(c.PeerId)
conn, err := net.Dial("tcp", peerAddr)
- defer func(conn net.Conn) {
- err := conn.Close()
- if err != nil {
- slog.Error("failed to close connection", "peerAddr", peerAddr)
- }
- }(conn)
if err != nil {
return nil, err
@@ -184,7 +167,23 @@ func (c *Client) Handshake(filename, peerAddr string) (HandshakeMessage, error)
}
respBuf := make([]byte, 68)
- _, err = io.LimitReader(conn, 68).Read(respBuf)
- return respBuf, nil
+ _, err = io.ReadFull(conn, respBuf)
+
+ peer := &Peer{
+ conn: conn,
+ handshake: respBuf,
+ ct: ct,
+ }
+
+ if !bytes.Equal(peer.InfoHash(), hash) {
+ err := conn.Close()
+ if err != nil {
+ slog.Error("Failed to close peer connection", "remoteAddr", peer.conn.RemoteAddr())
+ }
+
+ return nil, fmt.Errorf("invalid info hash from peer: %x, addr: %s", peer.InfoHash(), peerAddr)
+ }
+
+ return peer, nil
}
diff --git a/cmd/mybittorrent/main.go b/cmd/mybittorrent/main.go
index f000aa5..80a9b7d 100644
--- a/cmd/mybittorrent/main.go
+++ b/cmd/mybittorrent/main.go
@@ -4,17 +4,16 @@ import (
"encoding/json"
"fmt"
"os"
+ "strconv"
"strings"
bencode "github.com/jackpal/bencode-go" // Available if you need it!
)
-func createClient() *Client {
- fn := os.Args[2]
+func createClient(fn string) *Client {
c := NewClient("00112233445566778899", 6881)
- _, err := c.AddTorrentFile(fn)
- if err != nil {
+ if err := c.AddTorrentFile(fn); err != nil {
panic(err)
}
@@ -38,7 +37,7 @@ func main() {
case "info":
fn := os.Args[2]
- c := createClient()
+ c := createClient(fn)
meta := c.Torrents[fn].Meta
fmt.Printf("Tracker URL: %s\n", meta.Announce)
fmt.Printf("Length: %d\n", meta.Info.Length)
@@ -59,7 +58,7 @@ func main() {
case "peers":
fn := os.Args[2]
- c := createClient()
+ c := createClient(fn)
pr, err := c.GetPeers(fn)
if err != nil {
@@ -72,15 +71,52 @@ func main() {
case "handshake":
fn := os.Args[2]
- c := createClient()
+ c := createClient(fn)
peerAddr := os.Args[3]
- hs, err := c.Handshake(fn, peerAddr)
+ peer, err := c.Handshake(fn, peerAddr)
if err != nil {
panic(err)
}
- fmt.Printf("Peer ID: %s\n", hs.PeerIdHex())
+ fmt.Printf("Peer ID: %s\n", peer.PeerIdHexString())
+
+ case "download_piece":
+ out := os.Args[3]
+ fn := os.Args[4]
+
+ index, err := strconv.Atoi(os.Args[5])
+ if err != nil {
+ panic(err)
+ }
+
+ c := createClient(fn)
+ pr, err := c.GetPeers(fn)
+ if err != nil {
+ panic(err)
+ }
+
+ var peer *Peer
+
+ for _, peerAddr := range pr.Peers {
+ peer, err = c.Handshake(fn, peerAddr)
+ if err != nil {
+ continue
+ }
+ }
+
+ if peer == nil {
+ panic(fmt.Errorf("no peers found for file: %s", out))
+ }
+
+ defer peer.Close()
+
+ err = peer.DownloadPiece(out, index)
+ if err != nil {
+ panic(err)
+ }
+
+ fmt.Printf("Piece %d downloaded to %s.", index, out)
default:
fmt.Println("Unknown command: " + command)
diff --git a/cmd/mybittorrent/message.go b/cmd/mybittorrent/message.go
new file mode 100644
index 0000000..9439bd5
--- /dev/null
+++ b/cmd/mybittorrent/message.go
@@ -0,0 +1,76 @@
+package main
+
+import (
+ "encoding/binary"
+ "io"
+)
+
+type MessageType byte
+
+type BlockPayload struct {
+ Index uint32
+ Begin uint32
+ Block []byte
+}
+
+type IncomingMessage struct {
+ Len uint32
+ MessageType MessageType
+ Payload []byte
+}
+
+type OutgoingMessage struct {
+ MessageType MessageType
+ Writer io.Writer
+}
+
+type RequestPayload struct {
+ Index uint32
+ Begin uint32
+ Length uint32
+}
+
+const (
+ MessageTypeChoke MessageType = iota
+ MessageTypeUnchoke
+ MessageTypeInterested
+ MessageTypeNotInterested
+ MessageTypeHave
+ MessageTypeBitfield
+ MessageTypeRequest
+ MessageTypePiece
+ MessageTypeCancel
+)
+
+func (o *OutgoingMessage) Write(b []byte) (int, error) {
+ msgLen := 1 + len(b)
+ payloadBuff := make([]byte, msgLen+4)
+ binary.BigEndian.PutUint32(payloadBuff[0:4], uint32(msgLen))
+ payloadBuff[4] = byte(o.MessageType)
+
+ if msgLen > 1 {
+ copy(payloadBuff[5:], b)
+ }
+
+ return o.Writer.Write(payloadBuff)
+}
+
+func (r RequestPayload) Bytes() []byte {
+ buf := make([]byte, 12)
+ binary.BigEndian.PutUint32(buf[0:4], r.Index)
+ binary.BigEndian.PutUint32(buf[4:8], r.Begin)
+ binary.BigEndian.PutUint32(buf[8:12], r.Length)
+ return buf
+}
+
+func (p *BlockPayload) Write(b []byte) (int, error) {
+ p.Index = binary.BigEndian.Uint32(b[0:4])
+ p.Begin = binary.BigEndian.Uint32(b[4:8])
+ p.Block = b[8:]
+ return len(b), nil
+}
+
+func (p *BlockPayload) WriteTo(w io.Writer) (int64, error) {
+ n, err := w.Write(p.Block)
+ return int64(n), err
+}
diff --git a/cmd/mybittorrent/meta.go b/cmd/mybittorrent/meta.go
index 34369c4..27bcc12 100644
--- a/cmd/mybittorrent/meta.go
+++ b/cmd/mybittorrent/meta.go
@@ -1,11 +1,15 @@
package main
import (
+ "bytes"
"crypto/sha1"
+ "math"
bencode "github.com/jackpal/bencode-go"
)
+const BlockSize = 16 * 1024
+
type Meta struct {
Announce string `bencode:"announce"`
Info FileInfo `bencode:"info"`
@@ -35,3 +39,48 @@ func (m Meta) PieceHashes() []string {
return hashes
}
+
+func (m Meta) CheckHash(pieceIndex int, data []byte) bool {
+ sha := sha1.New()
+
+ if _, err := bytes.NewBuffer(data).WriteTo(sha); err != nil {
+ return false
+ }
+
+ return bytes.Equal([]byte(m.Info.Pieces[pieceIndex*20:pieceIndex*20+20]), sha.Sum(nil))
+}
+
+func (m Meta) PieceCount() int {
+ return len(m.Info.Pieces) / 20
+}
+
+func (m Meta) PieceLens() []int {
+ pieceCnt := m.PieceCount()
+ pieces := make([]int, pieceCnt)
+
+ for i := 0; i < pieceCnt; i++ {
+ if i < pieceCnt-1 {
+ pieces[i] = m.Info.PieceLength
+ } else {
+ pieces[i] = m.Info.Length - i*m.Info.PieceLength
+ }
+ }
+
+ return pieces
+}
+
+func (m Meta) BlockLens(pieceIdx int) []uint32 {
+ pieceLen := m.PieceLens()[pieceIdx]
+ blockCnt := int(math.Ceil(float64(pieceLen) / float64(BlockSize)))
+ blocks := make([]uint32, blockCnt)
+
+ for i := 0; i < blockCnt; i++ {
+ if i < blockCnt-1 {
+ blocks[i] = uint32(BlockSize)
+ } else {
+ blocks[i] = uint32(pieceLen - i*BlockSize)
+ }
+ }
+
+ return blocks
+}
diff --git a/cmd/mybittorrent/peer.go b/cmd/mybittorrent/peer.go
new file mode 100644
index 0000000..3b40493
--- /dev/null
+++ b/cmd/mybittorrent/peer.go
@@ -0,0 +1,181 @@
+package main
+
+import (
+ "bytes"
+ "encoding/binary"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "os"
+)
+
+type HandshakeMessage []byte
+
+type Peer struct {
+ conn net.Conn
+ handshake HandshakeMessage
+ ct *ClientTorrent
+ msgCh chan *IncomingMessage
+}
+
+func (p *Peer) PeerIdHexString() string {
+ return hex.EncodeToString(p.handshake[48:])
+}
+
+func (p *Peer) InfoHash() []byte {
+ return p.handshake[28:48]
+}
+
+func (p *Peer) Close() error {
+ err := p.conn.Close()
+ if err != nil {
+ return errors.Join(fmt.Errorf("failed to close peer connection: %s", p.conn.RemoteAddr()), err)
+ }
+
+ return nil
+}
+
+func (p *Peer) DownloadPiece(outFile string, index int) error {
+ if index >= p.ct.Meta.PieceCount() {
+ return nil
+ }
+
+ var data = new(bytes.Buffer)
+ blockLens := p.ct.Meta.BlockLens(index)
+ var blockIndex = 0
+
+ requestFn := func() error {
+ if blockIndex == len(blockLens) {
+ return nil
+ }
+
+ r := RequestPayload{
+ Index: uint32(index),
+ Begin: uint32(blockIndex * BlockSize),
+ Length: blockLens[blockIndex],
+ }
+
+ return p.WriteMessage(MessageTypeRequest, r.Bytes())
+ }
+
+ for {
+ if blockIndex == len(blockLens) {
+ break
+ }
+
+ msg, err := p.ReadMessage()
+ if err != nil {
+ return err
+ }
+
+ switch msg.MessageType {
+ case MessageTypeBitfield:
+ if err != nil {
+ return err
+ }
+
+ err = p.WriteMessage(MessageTypeInterested, nil)
+ if err != nil {
+ return err
+ }
+
+ case MessageTypeUnchoke:
+ if err = requestFn(); err != nil {
+ return err
+ }
+
+ case MessageTypePiece:
+ var block BlockPayload
+
+ _, err = block.Write(msg.Payload)
+ if err != nil {
+ return err
+ }
+
+ _, err = data.Write(block.Block)
+ if err != nil {
+ return err
+ }
+
+ blockIndex++
+
+ if err = requestFn(); err != nil {
+ return err
+ }
+
+ default:
+ return fmt.Errorf("unimplemented message type: %d", msg.MessageType)
+ }
+ }
+
+ if !p.ct.Meta.CheckHash(index, data.Bytes()) {
+ return fmt.Errorf("invalid hash value")
+ }
+
+ out, err := os.Create(outFile)
+ if err != nil {
+ return errors.Join(fmt.Errorf("failed to create file"), err)
+ }
+
+ _, err = out.Write(data.Bytes())
+ if err != nil {
+ return errors.Join(fmt.Errorf("failed to write file"), err)
+ }
+
+ err = out.Close()
+ if err != nil {
+ return errors.Join(fmt.Errorf("failed to close file"), err)
+ }
+
+ return nil
+}
+
+func (p *Peer) ReadMessage() (*IncomingMessage, error) {
+ lenBuf := make([]byte, 4)
+
+ _, err := p.conn.Read(lenBuf)
+ if err != nil {
+ return nil, errors.Join(fmt.Errorf("failed to read message length"), err)
+ }
+
+ msgLen := binary.BigEndian.Uint32(lenBuf) - 1
+ typeBuf := make([]byte, 1)
+
+ _, err = p.conn.Read(typeBuf)
+ if err != nil {
+ return nil, errors.Join(fmt.Errorf("failed to read message type"), err)
+ }
+
+ msgType := MessageType(typeBuf[0])
+ if msgLen == 0 {
+ return &IncomingMessage{
+ Len: msgLen,
+ MessageType: msgType,
+ }, nil
+ }
+
+ payloadBuf := make([]byte, msgLen)
+
+ _, err = io.ReadFull(p.conn, payloadBuf)
+ if err != nil {
+ return nil, errors.Join(fmt.Errorf("failed to read message payload"), err)
+ }
+
+ return &IncomingMessage{
+ Len: msgLen,
+ MessageType: msgType,
+ Payload: payloadBuf,
+ }, nil
+}
+
+func (p *Peer) WriteMessage(t MessageType, payload []byte) error {
+ msg := &OutgoingMessage{
+ MessageType: t,
+ Writer: p.conn,
+ }
+
+ _, err := msg.Write(payload)
+ return err
+}