#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
pt-trades-export.py — Portfolio Tester · export your broker trades to the
cost-basis widget CSV (mission 7.154).

WHAT IT DOES
  Reads your filled trades from a broker (Alpaca or tastytrade), for ONE symbol,
  and writes a CSV in the exact format the "Multi-currency cost basis" widget
  imports:

      date;type;qty;price_usd;fees_usd;rate_eur_usd

  By default it also fills the ECB euro/dollar reference rate (1 EUR = X USD) for
  each trade date, so the import is ready to compute in euros. For a date with no
  ECB quote (weekend / holiday) the last published rate before that date is used.

SECURITY — read this once
  * Your API keys are read from ENVIRONMENT VARIABLES or an interactive prompt.
    They are NEVER taken as command-line arguments (those show up in `ps` and in
    your shell history).
  * This script runs on YOUR machine. Your keys are sent only to the broker you
    chose (over HTTPS) to fetch your own trades, and nothing else. The ECB rate
    request carries no credentials.
  * Nothing is written to disk except the output CSV. No keys, no tokens, no logs.
  * Use read-only / trading-disabled keys if your broker offers them, and revoke
    the keys afterwards if you wish.

REQUIREMENTS
  Python 3.6+ standard library only. No pip install, no dependencies.

USAGE
  List the symbols found, then export one of them.

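  Alpaca (paper account shown; drop --paper for live). Alpaca fills carry no
  commission, so fees_usd is written as 0; regulatory fees (reported by Alpaca
  as separate FEE/REG/TAF activities) are not included: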
      export APCA_API_KEY_ID=...        # or you will be prompted
      export APCA_API_SECRET_KEY=...
      python3 pt-trades-export.py alpaca --paper            # lists symbols
      python3 pt-trades-export.py alpaca --paper --symbol SPY

  tastytrade (production; add --cert for the sandbox):
      export TT_LOGIN=you@example.com   # or you will be prompted
      export TT_PASSWORD=...
      python3 pt-trades-export.py tastytrade               # lists symbols
      python3 pt-trades-export.py tastytrade --symbol AAPL --otp 123456

  Common options:
      --from YYYY-MM-DD  --to YYYY-MM-DD   restrict the date window
      --out FILE.csv                       output file (default <SYM>-operations.csv)
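      --currency EUR|GBP|CHF|SEK|DKK|NOK|PLN|CZK   reference currency for the rate
                                           column (default EUR). For a non-EUR
                                           currency the value is cross-rated from
                                           ECB data as 1 <CCY> = X USD; the column
                                           keeps its rate_eur_usd header.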
      --no-rates                           do not fetch ECB rates (leave rate empty)
      --account NUMBER                     tastytrade: pick an account if several
"""

import argparse
import bisect
import csv
import datetime
import getpass
import io
import json
import math
import os
import sys
import urllib.error
import urllib.parse
import urllib.request

# Header line the cost-basis widget imports; write_csv emits it verbatim as the
# first line of every output file, so it must not change.
CSV_HEADER = "date;type;qty;price_usd;fees_usd;rate_eur_usd"  # widget 7.153 — exact
# User-Agent sent by the HTTP helpers on every request (brokers and ECB).
USER_AGENT = "pt-trades-export/1.0 (+https://portfolio-tester.com)"
# Per-request network timeout, in seconds, used by the HTTP helpers.
TIMEOUT = 30


# --------------------------------------------------------------------------- #
# Small HTTP helper (stdlib only, HTTPS only)
# --------------------------------------------------------------------------- #
def http_json(url, headers=None, method="GET", body=None):
    """GET/POST and return parsed JSON. Raises urllib errors on failure."""
    if not url.lower().startswith("https://"):
        raise ValueError("Refusing non-HTTPS URL: %s" % url)
    data = None
    hdrs = {"User-Agent": USER_AGENT, "Accept": "application/json"}
    if headers:
        hdrs.update(headers)
    if body is not None:
        data = json.dumps(body).encode("utf-8")
        hdrs["Content-Type"] = "application/json"
    req = urllib.request.Request(url, data=data, headers=hdrs, method=method)
    with urllib.request.urlopen(req, timeout=TIMEOUT) as resp:
        raw = resp.read().decode("utf-8")
    return json.loads(raw) if raw else {}


def http_text(url):
    """Download *url* over HTTPS and return its body decoded as UTF-8 text.

    Used for the ECB CSV, which needs no credentials, so only the User-Agent
    header is sent. A URL not starting with https:// (case-insensitive) is
    rejected with ValueError before any network access.
    """
    if url.lower().startswith("https://"):
        request = urllib.request.Request(url, headers={"User-Agent": USER_AGENT})
        with urllib.request.urlopen(request, timeout=TIMEOUT) as response:
            payload = response.read()
        return payload.decode("utf-8")
    raise ValueError("Refusing non-HTTPS URL: %s" % url)


def die(msg, code=1):
    """Report ``error: <msg>`` on stderr and terminate with exit status *code*."""
    print("error: %s" % msg, file=sys.stderr)
    sys.exit(code)


def env_or_prompt(env_name, prompt_label, secret=False):
    """Return a credential from the environment, falling back to a prompt.

    Credential hygiene: values are never taken from argv. A non-empty
    environment variable *env_name* wins. Otherwise the user is asked with
    ``"<prompt_label>: "``. A secret prompt uses getpass (no echo) and keeps
    the answer verbatim; a plain prompt strips surrounding whitespace.
    EOF or Ctrl-C at the prompt aborts via die().
    """
    from_env = os.environ.get(env_name)
    if from_env:
        return from_env
    question = "%s: " % prompt_label
    try:
        return getpass.getpass(question) if secret else input(question).strip()
    except (EOFError, KeyboardInterrupt):
        die("no credentials provided")


# --------------------------------------------------------------------------- #
# A normalised trade: dict {date, type, qty, price, fees, symbol}
#   date  : "YYYY-MM-DD" (first 10 characters of the broker's timestamp)
#   type  : "buy" | "sell"
#   qty   : str (positive)
#   price : str (per-share USD)
#   fees  : str (USD, >= 0)
# main() later adds rate_eur_usd : str (formatted rate, or "" when unknown).
def _fnum(x):
    try:
        return float(x)
    except (TypeError, ValueError):
        return 0.0


# --------------------------------------------------------------------------- #
# Alpaca
# --------------------------------------------------------------------------- #
def alpaca_fetch_fills(args):
    """Return every Alpaca FILL activity in the window as normalised trades.

    Credentials come from APCA_API_KEY_ID / APCA_API_SECRET_KEY or a prompt.
    ``args.paper`` selects the paper endpoint; ``args.date_from`` /
    ``args.date_to`` (YYYY-MM-DD) become the ``after`` / ``until`` filters.
    Alpaca cannot filter by symbol here, so all symbols are returned and the
    caller filters.

    The side is matched by prefix: anything starting with "sell" (including a
    short sale reported as "sell_short") is a sell, anything starting with
    "buy" is a buy. Fills with any other side are skipped and counted in a
    warning rather than silently booked as buys.
    """
    key = env_or_prompt("APCA_API_KEY_ID", "Alpaca API key id")
    secret = env_or_prompt("APCA_API_SECRET_KEY", "Alpaca API secret key", secret=True)
    base = "https://paper-api.alpaca.markets" if args.paper else "https://api.alpaca.markets"
    headers = {"APCA-API-KEY-ID": key, "APCA-API-SECRET-KEY": secret}

    # Largest page the activities endpoint accepts (per Alpaca docs); also
    # the "short page => last page" test below.
    page_size = 100
    params = {"activity_types": "FILL", "page_size": str(page_size), "direction": "asc"}
    if args.date_from:
        params["after"] = args.date_from + "T00:00:00Z"
    if args.date_to:
        params["until"] = args.date_to + "T23:59:59Z"

    trades = []
    skipped = 0
    page_token = None
    while True:
        q = dict(params)
        if page_token:
            q["page_token"] = page_token
        url = base + "/v2/account/activities?" + urllib.parse.urlencode(q)
        batch = http_json(url, headers=headers)
        if not isinstance(batch, list) or not batch:
            break
        for a in batch:
            if a.get("activity_type") != "FILL":
                continue
            side = str(a.get("side") or "").lower()
            if side.startswith("sell"):
                kind = "sell"
            elif side.startswith("buy"):
                kind = "buy"
            else:
                skipped += 1
                continue
            ts = a.get("transaction_time") or ""
            trades.append({
                "symbol": a.get("symbol", ""),
                "date": ts[:10],
                "type": kind,
                "qty": str(a.get("qty", "")),
                "price": str(a.get("price", "")),
                # A FILL carries no commission; regulatory fees are separate
                # activities (FEE/REG/TAF). v1 reports fees as 0 — documented.
                "fees": "0",
            })
        # direction=asc: the next page starts after the last id returned.
        page_token = batch[-1].get("id")
        if len(batch) < page_size or not page_token:
            break
    if skipped:
        sys.stderr.write("warning: skipped %d Alpaca fill(s) with an unrecognised side.\n" % skipped)
    return trades


# --------------------------------------------------------------------------- #
# tastytrade
# --------------------------------------------------------------------------- #
# tastytrade fee columns; their absolute values are summed into a trade's fees.
_TT_FEE_FIELDS = ("commission", "regulatory-fees", "clearing-fees",
                  "proprietary-index-option-fees")


def _tt_side(action):
    """Return "buy"/"sell" from an action such as "Buy to Open", else None."""
    lowered = action.lower()
    for side in ("buy", "sell"):
        if lowered.startswith(side):
            return side
    return None


def _tt_normalise(txn):
    """Turn one tastytrade transaction into a normalised trade, or None to skip it."""
    if txn.get("transaction-type") != "Trade":
        return None
    side = _tt_side(txn.get("action") or "")
    if side is None:
        return None
    executed = txn.get("executed-at", "")
    total_fees = sum(abs(_fnum(txn.get(name))) for name in _TT_FEE_FIELDS)
    if total_fees:
        fees_text = ("%.4f" % total_fees).rstrip("0").rstrip(".")
    else:
        fees_text = "0"
    return {
        "symbol": txn.get("symbol", ""),
        "date": executed[:10],
        "type": side,
        "qty": str(txn.get("quantity", "")),
        "price": str(txn.get("price", "")),
        "fees": fees_text,
    }


def _tt_choose_account(numbers, requested):
    """Pick the account: the requested one, the only one, or list them and exit 0."""
    if requested:
        if requested not in numbers:
            die("account %s not found. Available: %s" % (requested, ", ".join(numbers)))
        return requested
    if len(numbers) == 1:
        return numbers[0]
    sys.stderr.write("Multiple accounts — re-run with --account NUMBER:\n")
    for number in numbers:
        sys.stderr.write("  %s\n" % number)
    sys.exit(0)


def tastytrade_fetch_fills(args):
    """Log in to tastytrade and return its Trade transactions as normalised trades.

    Credentials come from TT_LOGIN / TT_PASSWORD or a prompt; ``args.otp``, if
    given, is sent as the X-Tastyworks-OTP header and ``args.cert`` selects
    the sandbox host. Symbol and date window are passed as server-side
    filters. Exits via die() on a failed login, a missing session token, no
    accounts or an unknown --account. When several accounts exist and none
    was chosen, it lists them and exits with status 0.
    """
    host = "https://api.cert.tastyworks.com" if args.cert else "https://api.tastyworks.com"
    login = env_or_prompt("TT_LOGIN", "tastytrade login (email)")
    password = env_or_prompt("TT_PASSWORD", "tastytrade password", secret=True)

    # 1) Session token. It goes raw into Authorization, NOT "Bearer ..." —
    #    that form belongs to tastytrade's OAuth flow.
    login_headers = {"X-Tastyworks-OTP": args.otp} if args.otp else {}
    credentials = {"login": login, "password": password, "remember-me": False}
    try:
        session = http_json(host + "/sessions", headers=login_headers,
                            method="POST", body=credentials)
    except urllib.error.HTTPError as err:
        die("tastytrade login failed (HTTP %s). Check credentials / OTP." % err.code)
    token = (session.get("data") or {}).get("session-token")
    if not token:
        die("tastytrade login returned no session token")
    auth = {"Authorization": token}

    # 2) Account selection.
    listing = http_json(host + "/customers/me/accounts", headers=auth)
    entries = (listing.get("data") or {}).get("items") or []
    candidates = ((entry.get("account") or {}).get("account-number") for entry in entries)
    numbers = [number for number in candidates if number]
    if not numbers:
        die("no tastytrade accounts found for this login")
    account = _tt_choose_account(numbers, args.account)
    endpoint = "%s/accounts/%s/transactions" % (host, urllib.parse.quote(account))

    # 3) Trade transactions, oldest first, paginated by page-offset.
    filters = (("symbol", args.symbol), ("start-date", args.date_from), ("end-date", args.date_to))
    trades = []
    offset = 0
    while True:
        query = {"type": "Trade", "per-page": "250", "page-offset": str(offset), "sort": "Asc"}
        query.update((name, value) for name, value in filters if value)
        page = http_json(endpoint + "?" + urllib.parse.urlencode(query), headers=auth)
        items = (page.get("data") or {}).get("items") or []
        trades.extend(trade for trade in map(_tt_normalise, items) if trade is not None)
        offset += 1
        total_pages = (page.get("pagination") or {}).get("total-pages")
        if not items or (isinstance(total_pages, int) and offset >= total_pages):
            break
    return trades


# --------------------------------------------------------------------------- #
# ECB euro reference exchange rate. Series D.<CCY>.EUR.SP00.A gives <CCY> per EUR.
# For EUR reference (default) the USD series is "1 EUR = X USD" directly. For a
# non-EUR reference currency the script cross-rates: "1 CCY = X USD" = (USD per EUR)
# / (CCY per EUR) on the same date, each side carried forward independently.
# --------------------------------------------------------------------------- #
def ecb_rate_map(date_from, date_to, ccy="USD"):
    """Return {date: <ccy> per EUR} for the window, or None if nothing usable.

    Downloads the ECB daily reference series D.<ccy>.EUR.SP00.A as CSV for
    [date_from, date_to] ("YYYY-MM-DD" strings) and maps TIME_PERIOD ->
    float(OBS_VALUE).

    Failure handling is deliberately non-fatal: any fetch error prints a
    warning and returns None, so the caller leaves the rate column empty.
    Rows whose OBS_VALUE is missing, unparsable, non-finite (float() accepts
    "NaN"/"inf") or non-positive are skipped. The caller's carry-forward then
    uses the last real quote instead of writing NaN or dividing by zero in a
    cross-rate. Returns None as well when no row survives.

    The caller carries the last published rate forward for missing days."""
    url = ("https://data-api.ecb.europa.eu/service/data/EXR/D.%s.EUR.SP00.A"
           "?startPeriod=%s&endPeriod=%s&format=csvdata"
           % (urllib.parse.quote(ccy), urllib.parse.quote(date_from), urllib.parse.quote(date_to)))
    try:
        text = http_text(url)
    except Exception as e:  # noqa: BLE001 — never fatal, fall back to empty rates
        sys.stderr.write("warning: ECB rate fetch failed for %s (%s) — rate column left empty.\n" % (ccy, e))
        return None
    rates = {}
    for row in csv.DictReader(io.StringIO(text)):
        day = row.get("TIME_PERIOD")
        raw = row.get("OBS_VALUE")
        if not (day and raw):
            continue
        try:
            value = float(raw)
        except ValueError:
            continue
        if math.isfinite(value) and value > 0:
            rates[day] = value
    return rates or None


def carry_forward_lookup(rates, sorted_dates, date):
    """Return the rate published on *date* or, failing that, the latest one before it.

    *sorted_dates* must be the keys of *rates* in ascending order (ISO
    "YYYY-MM-DD" strings sort chronologically). Returns '' when no published
    date is on or before *date*.
    """
    pos = bisect.bisect_right(sorted_dates, date)
    return rates[sorted_dates[pos - 1]] if pos else ""


# --------------------------------------------------------------------------- #
# Output
# --------------------------------------------------------------------------- #
def write_csv(path, rows):
    """Write *rows* to *path* in the widget's semicolon-separated import format.

    Deliberately a plain text join rather than csv.writer, so the output
    matches the widget byte-for-byte: the fixed CSV_HEADER line, then one line
    per row with its date, type, qty, price, fees and rate_eur_usd strings,
    every line ending in "\n" (newline="" disables platform translation),
    UTF-8 encoded. All lines are built before the file is opened.
    """
    columns = ("date", "type", "qty", "price", "fees", "rate_eur_usd")
    body = [CSV_HEADER] + [";".join(row[col] for col in columns) for row in rows]
    with open(path, "w", encoding="utf-8", newline="") as handle:
        handle.write("\n".join(body) + "\n")


def fmt_rate(v, sig_digits=5):
    """Format a rate for the CSV rate column.

    '' or None -> '' (blank cell). A non-float value -> str(v). A float is
    printed with at least 4 decimals AND at least *sig_digits* significant
    digits, with trailing zeros (and a bare '.') trimmed.

    The significant-digit floor matters for cross-rates below 1 (e.g. 1 CZK
    is roughly 0.04 USD), which a flat 4 decimals would cut to three
    significant digits. With the default sig_digits=5, any rate >= 1 gets
    exactly 4 decimals. Non-finite floats (NaN/inf) yield '' so the cell is
    left for manual entry.
    """
    if v == "" or v is None:
        return ""
    if not isinstance(v, float):
        return str(v)
    if not math.isfinite(v):
        return ""
    if v == 0:
        return "0"
    magnitude = int(math.floor(math.log10(abs(v))))
    decimals = max(4, sig_digits - 1 - magnitude)
    return ("%.*f" % (decimals, v)).rstrip("0").rstrip(".")


# --------------------------------------------------------------------------- #
# Main
# --------------------------------------------------------------------------- #
def build_parser():
    """Build the command-line parser: one sub-command per broker.

    The parsed namespace has ``broker`` ("alpaca" | "tastytrade"), the shared
    options ``symbol``, ``date_from``, ``date_to`` (canonical "YYYY-MM-DD"
    strings or None), ``out``, ``currency`` and ``no_rates``, plus the
    broker-specific flags (``paper``; ``cert``/``otp``/``account``).
    No credential is ever accepted as an argument.
    """
    def iso_date(text):
        # Validate early so a typo fails with a usage message instead of a
        # broker-side HTTP error; returns the canonical zero-padded form.
        try:
            parsed = datetime.datetime.strptime(text, "%Y-%m-%d")
        except ValueError:
            raise argparse.ArgumentTypeError("expected YYYY-MM-DD, got %r" % text) from None
        return parsed.strftime("%Y-%m-%d")

    p = argparse.ArgumentParser(
        prog="pt-trades-export.py",
        description="Export broker trades to the Portfolio Tester cost-basis CSV. "
                    "Credentials come from environment variables or an interactive "
                    "prompt — never from the command line.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="Keys are sent only to the chosen broker. Nothing but the CSV is "
               "written to disk. See the header of this file for details.",
    )
    # add_subparsers(required=...) only exists from Python 3.7; setting the
    # attribute keeps the documented Python 3.6+ support.
    sub = p.add_subparsers(dest="broker")
    sub.required = True

    common = argparse.ArgumentParser(add_help=False)
    common.add_argument("--symbol", help="ticker to export (omit to list symbols found)")
    common.add_argument("--from", dest="date_from", metavar="YYYY-MM-DD", type=iso_date,
                        help="start date")
    common.add_argument("--to", dest="date_to", metavar="YYYY-MM-DD", type=iso_date,
                        help="end date")
    common.add_argument("--out", help="output CSV (default <SYMBOL>-operations.csv)")
    common.add_argument("--currency", default="EUR",
                        choices=["EUR", "GBP", "CHF", "SEK", "DKK", "NOK", "PLN", "CZK"],
                        help="reference currency for the rate column (default EUR); "
                             "non-EUR is cross-rated from ECB reference rates")
    common.add_argument("--no-rates", action="store_true", help="do not fetch ECB rates")

    a = sub.add_parser("alpaca", parents=[common], help="Alpaca trade fills")
    a.add_argument("--paper", action="store_true", help="use the paper account endpoint")

    t = sub.add_parser("tastytrade", parents=[common], help="tastytrade trade transactions")
    t.add_argument("--cert", action="store_true", help="use the certification (sandbox) host")
    t.add_argument("--otp", help="one-time passcode if your account uses 2FA")
    t.add_argument("--account", help="account number (if the login has several)")
    return p


def main(argv=None):
    """Run the exporter and return the process exit status.

    Flow:
      1. Parse arguments and fetch the chosen broker's fills.
      2. With no --symbol, list the symbols found and stop.
      3. Otherwise write that symbol's trades, oldest first, to the widget
         CSV, optionally with ECB reference rates.

    Broker HTTP/network/JSON failures are reported via die() (exit status 1)
    instead of a traceback. An empty result still writes a header-only CSV
    when a symbol was requested.
    """
    args = build_parser().parse_args(argv)

    # The ECB request carries no credentials; phrase the notice so it cannot
    # be read as "keys are also sent to the ECB".
    sys.stderr.write(
        "pt-trades-export: your keys are read from the environment or a prompt, "
        "sent only to %s (the ECB rate request carries no credentials), and "
        "never written to disk.\n" % args.broker
    )

    if args.broker == "alpaca":
        fetch = alpaca_fetch_fills
    elif args.broker == "tastytrade":
        fetch = tastytrade_fetch_fills
    else:  # pragma: no cover — argparse enforces choices
        die("unknown broker")
    try:
        trades = fetch(args)
    except urllib.error.HTTPError as e:
        # Typically 401/403 for wrong keys or the wrong environment
        # (e.g. live keys used with --paper).
        die("%s request failed (HTTP %s %s). Check your credentials and options."
            % (args.broker, e.code, e.reason))
    except OSError as e:  # URLError, timeouts, DNS/TLS failures
        die("could not reach %s: %s" % (args.broker, getattr(e, "reason", e)))
    except json.JSONDecodeError as e:
        die("unexpected (non-JSON) response from %s: %s" % (args.broker, e))

    if not trades:
        sys.stderr.write("No trades found.\n")
        # still emit an (empty) CSV if a symbol was requested, so the round-trip is testable
        if args.symbol:
            out = args.out or ("%s-operations.csv" % args.symbol)
            write_csv(out, [])
            sys.stderr.write("Wrote empty %s (header only).\n" % out)
        return 0

    # No symbol → list distinct symbols and stop.
    if not args.symbol:
        counts = {}
        for tr in trades:
            counts[tr["symbol"]] = counts.get(tr["symbol"], 0) + 1
        sys.stderr.write("Symbols found (pass one with --symbol):\n")
        for sym in sorted(counts):
            sys.stderr.write("  %-12s %d trade(s)\n" % (sym, counts[sym]))
        return 0

    out = args.out or ("%s-operations.csv" % args.symbol)
    rows = [tr for tr in trades if tr["symbol"] == args.symbol]
    if not rows:
        sys.stderr.write("No trades for %s. Run without --symbol to list symbols.\n" % args.symbol)
        write_csv(out, [])
        sys.stderr.write("Wrote empty %s (header only).\n" % out)
        return 0

    # chronological order (stable on equal dates)
    rows.sort(key=lambda r: r["date"])

    # ECB rate enrichment
    missing = 0
    if args.no_rates:
        for r in rows:
            r["rate_eur_usd"] = ""
    else:
        hi = rows[-1]["date"]
        # widen the start by ~10 days so a Monday/holiday trade can borrow the
        # previous published (e.g. Friday) rate.
        try:
            lo_wide = (datetime.datetime.strptime(rows[0]["date"], "%Y-%m-%d")
                       - datetime.timedelta(days=10)).strftime("%Y-%m-%d")
        except ValueError:
            lo_wide = rows[0]["date"]
        usd = ecb_rate_map(lo_wide, hi, "USD")            # USD per EUR
        ref = None if args.currency == "EUR" else ecb_rate_map(lo_wide, hi, args.currency)
        if usd is None or (args.currency != "EUR" and ref is None):
            for r in rows:
                r["rate_eur_usd"] = ""
            missing = len(rows)
        else:
            usd_sd = sorted(usd.keys())
            ref_sd = sorted(ref.keys()) if ref else None
            for r in rows:
                u = carry_forward_lookup(usd, usd_sd, r["date"])
                if args.currency == "EUR":
                    v = u                                  # 1 EUR = u USD
                else:
                    c = carry_forward_lookup(ref, ref_sd, r["date"])
                    # 1 CCY = (USD per EUR) / (CCY per EUR) USD
                    v = (u / c) if (u != "" and c not in ("", 0)) else ""
                if v == "":
                    missing += 1
                r["rate_eur_usd"] = fmt_rate(v)

    write_csv(out, rows)

    dmin = rows[0]["date"]
    dmax = rows[-1]["date"]
    sys.stderr.write(
        "Wrote %s — %d trade(s) for %s, %s to %s, %d missing rate(s).\n"
        % (out, len(rows), args.symbol, dmin, dmax, missing)
    )
    if missing and not args.no_rates:
        sys.stderr.write("  (fill the blank rate cells manually in the widget.)\n")
    return 0


if __name__ == "__main__":
    # main() returns the exit status; SystemExit hands it to the interpreter.
    raise SystemExit(main())
